Skip to content

Commit bcc54eb

Browse files
committed
add a user guide
1 parent 28de334 commit bcc54eb

File tree

2 files changed

+147
-4
lines changed

2 files changed

+147
-4
lines changed

README.md

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
This repository provides the implementation of **LLMCompass** from the following papers:
66

7-
**LLMCompass: Enabling Efficient Hardware Design for Large Language Model Inference**
7+
[**LLMCompass: Enabling Efficient Hardware Design for Large Language Model Inference**](https://parallel.princeton.edu/papers/isca24_llmcompass.pdf)
88

99
*Hengrui Zhang, August Ning, Rohan Baskar Prabhakar, David Wentzlaff*
1010

1111

12-
To appear in the Proceedings of the 51st Annual International Symposium on Computer Architecture:
12+
In the Proceedings of the 51st Annual International Symposium on Computer Architecture:
1313

1414
```
1515
@inproceedings{LLMCompass,
@@ -52,7 +52,7 @@ A Dockerfile has been provided (`./Dockerfile`), including all the software depe
5252

5353
A docker image has been provided [here](https://github.com/HenryChang213/LLMCompass_ISCA_AE_docker).
5454

55-
## Experiment workflow
55+
## AE Experiment workflow
5656
```
5757
# Figure 5 (around 100 min)
5858
$ cd ae/figure5
@@ -87,9 +87,14 @@ $ cd ae/figure12
8787
$ bash run_figure12.sh
8888
```
8989

90-
## Expected result
90+
## AE Expected result
9191

9292
After running each script above, the corresponding figures
9393
will be generated under the corresponding directory as suggested by its name.
9494

9595
For comparison, a copy of the expected results can be found in `ae\expected_results`
96+
97+
98+
## User Guide
99+
100+
A guide on "How to Run a LLMCompass Simulation" is shown [here](./docs/run.md).

docs/run.md

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# User Guide: How to Run a LLMCompass Simulation
2+
3+
## Step 1: Build a Hardware Configuration
4+
5+
Follow the [NVIDIA GA100 example](../configs/GA100.json). This is a 4-GA100 node connected with NVLinks.
6+
7+
### Explanations on the Knobs
8+
Most of the attributes are self-explained:
9+
10+
```json
11+
{
12+
"name": "NVIDIA A100(80GB)x4",
13+
"device_count": 4, # how many devices in a node
14+
"interconnect": {
15+
"link": {
16+
"name": "NVLink3",
17+
"bandwidth_per_direction_byte": 25e9,
18+
"bandwidth_both_directions_byte": 50e9,
19+
"latency_second": 8.92e-6,
20+
"flit_size_byte": 16,
21+
"header_size_byte": 16,
22+
"max_payload_size_byte": 256
23+
},
24+
"link_count_per_device": 12,
25+
"topology": "FC" # currently support FC (fully-connected) and RING
26+
},
27+
"device": {
28+
"frequency_Hz": 1410e6,
29+
"compute_chiplet_count": 1,
30+
"compute_chiplet": {
31+
"physical_core_count": 128, # used for area model
32+
"core_count": 128, # used for performance model
33+
"process_node": "7nm", # currently support 7nm, 6nm, 5nm
34+
"core": {
35+
"sublane_count": 4,
36+
"systolic_array": {
37+
"array_width": 16,
38+
"array_height": 16,
39+
"data_type": "fp16",
40+
"mac_per_cycle": 1
41+
},
42+
"vector_unit": {
43+
"vector_width": 32,
44+
"flop_per_cycle": 4, # 32*4=128 flops per cycle per vector unit
45+
"data_type": "fp16",
46+
"int32_count": 16, # the number of int32 ALUs, used for area model
47+
"fp16_count": 0,
48+
"fp32_count": 16,
49+
"fp64_count": 8
50+
},
51+
"register_file": {
52+
"num_reg_files": 1,
53+
"num_registers": 16384,
54+
"register_bitwidth":32,
55+
"num_rdwr_ports":4
56+
},
57+
"SRAM_KB": 192
58+
}
59+
},
60+
"memory_protocol": "HBM2e",
61+
"_memory_protocol_list": [
62+
"HBM2e",
63+
"DDR4",
64+
"DDR5",
65+
"PCIe4",
66+
"PCIe5"
67+
],
68+
"io": {
69+
"process_node": "7nm",
70+
"global_buffer_MB": 48,
71+
"physical_global_buffer_MB": 48,
72+
"global_buffer_bandwidth_per_cycle_byte": 5120,
73+
"memory_channel_physical_count": 6, # used for area model
74+
"memory_channel_active_count": 5, # used for performance model
75+
"pin_count_per_channel": 1024,
76+
"bandwidth_per_pin_bit": 3.2e9
77+
},
78+
"memory": {
79+
"total_capacity_GB": 80
80+
}
81+
}
82+
}
83+
84+
```
85+
86+
## Step 2: Build a LLM Computational Graph
87+
88+
Transformer blocks have been provided as in [`transformer.py`](../software_model/transformer.py), including Initial Computation (also called Prefill or Context stage) and Auto Regression (also called Decoding or Generation stage), with Tensor Parallelism support (automatically turned of if the system only has 1 device).
89+
90+
The user needs to provide these parameter:
91+
* `d_model`: the hidden dimension, 12288 for GPT3
92+
* `n_heads`: the number of heads, 96 for GPT3
93+
* `device_count`: tensor parallelism
94+
* `data_type`: `int8`, `fp16`, or `fp32`
95+
96+
### Build Your Own LLM
97+
98+
The user can also build their own computational graph following the [`transformer.py`](../software_model/transformer.py) example using provided operators: [`matmul`](../software_model/matmul.py), [`softmax`](../software_model/softmax.py), [`layernorm`](../software_model/layernorm.py), [`gelu`](../software_model/gelu.py), and [`allreduce`](../software_model/communication_primitives.py).
99+
100+
The user needs to define a new `class` by inheriting `Operator` the class and configure these fields:
101+
* `__init__`: define the needed operators in the initial function
102+
* `__call__`: build the computational graph. The shape of Tensors will be automatically calculated and used for simulation.
103+
* `compile_and_simulate`: simulate all the operators and get the total latency as well as other runtimes.
104+
* `roofline_model` (optional): a roofline model analysis.
105+
* `run_on_gpu` (optional): run the computational graph on real-world GPUs with PyTorch.
106+
107+
## Step 3: Run a LLMCompass Simulation
108+
109+
First, read the hardware configuration and parse it to LLMCompass:
110+
```python
111+
from design_space_exploration.dse import template_to_system, read_architecture_template
112+
113+
specs = read_architecture_template("PATH/TO/YOUR/JSON")
114+
system = template_to_system(specs)
115+
116+
```
117+
118+
Next, initiate and instantiate an LLM as in this example:
119+
```python
120+
model_auto_regression = TransformerBlockAutoRegressionTP(
121+
d_model=12288,
122+
n_heads=96,
123+
device_count=1,
124+
data_type=data_type_dict["fp16"],
125+
)
126+
_ = model_auto_regression(
127+
Tensor([bs, 1, 12288], data_type_dict["fp16"]),
128+
seq_len,
129+
)
130+
131+
```
132+
133+
Finally, run the simulation
134+
```
135+
auto_regression_latency_simulated = model_auto_regression.compile_and_simulate(
136+
system, "heuristic-GPU"
137+
)
138+
```

0 commit comments

Comments
 (0)