Commit 9dd1e62

Merge pull request #290 from DeepWok/cx/comput_in_memory
Cx/comput in memory

2 parents 25a44f2 + 449d717

File tree: 23 files changed, +1437 -3 lines changed

configs/pim/original.yaml

Lines changed: 16 additions & 0 deletions

```yaml
by: "type"

conv2d:
  config:
    tile_type: "original"
    core_size: 16
    # rescale_dim: "vector"
    # x_quant_type: "e5m2"
    # weight_quant_type: "e4m3"

linear:
  config:
    tile_type: "original"
    core_size: 16
    # rescale_dim: "vector"
    # x_quant_type: "e5m2"
    # weight_quant_type: "e4m3"
```
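These YAML files share the layout of the `q_config` dictionary that `pim_matmul_transform_pass` expects (see the tutorial below). A minimal sketch of wiring one up, assuming PyYAML is installed and `model` is an `nn.Module` defined elsewhere; the path is illustrative:

```python
import yaml

from chop.passes.module.transforms import pim_matmul_transform_pass

# Illustrative path: any of the configs under configs/pim/ follows this layout.
with open("configs/pim/original.yaml") as f:
    q_config = yaml.safe_load(f)  # {"by": "type", "conv2d": {...}, "linear": {...}}

# `model` is assumed to be a torch nn.Module defined elsewhere.
qmodel, _ = pim_matmul_transform_pass(model, q_config)
```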

configs/pim/pcm.yaml

Lines changed: 20 additions & 0 deletions

```yaml
by: "type"

conv2d:
  config:
    tile_type: "pcm"
    core_size: 256
    num_bits: 8
    programming_noise: true
    read_noise: true
    ir_drop: true
    out_noise: true

linear:
  config:
    tile_type: "pcm"
    core_size: 256
    num_bits: 8
    programming_noise: true
    read_noise: false
    ir_drop: false
    out_noise: false
```
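For intuition only: in hardware-aware models of PCM (see the Nature paper linked in the tutorial below), `programming_noise` is typically treated as a perturbation of the conductances written to the array. A toy sketch of that idea, not Mase's actual implementation; the `sigma` value is invented:

```python
import torch

def toy_programming_noise(weight: torch.Tensor, sigma: float = 0.02) -> torch.Tensor:
    # Illustration only: perturb each stored weight with Gaussian noise
    # whose standard deviation scales with the largest weight magnitude.
    w_max = weight.abs().max()
    return weight + sigma * w_max * torch.randn_like(weight)
```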

configs/pim/reram.yaml

Lines changed: 14 additions & 0 deletions

```yaml
by: "type"

conv2d:
  config:
    tile_type: "reram"
    core_size: 256
    num_bits: 3
    noise_magnitude: 0.1

linear:
  config:
    tile_type: "reram"
    core_size: 256
    num_bits: 3
    noise_magnitude: 0.1
```
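Reading this config, a ReRAM tile stores weights at 3-bit resolution and perturbs them with `noise_magnitude`. A hedged sketch of what a write-then-read cycle could look like under those assumptions (again, not Mase's actual noise model):

```python
import torch

def toy_reram_write_read(weight: torch.Tensor, num_bits: int = 3,
                         noise_magnitude: float = 0.1) -> torch.Tensor:
    # Illustration only: symmetric uniform quantization to 2**num_bits levels,
    # then Gaussian noise on the stored quantized levels.
    qmax = 2 ** (num_bits - 1) - 1
    scale = weight.abs().max() / qmax
    q = torch.clamp(torch.round(weight / scale), -qmax - 1, qmax)
    q_noisy = q + noise_magnitude * torch.randn_like(q)
    return q_noisy * scale
```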

configs/pim/sram.yaml

Lines changed: 16 additions & 0 deletions

```yaml
by: "type"

# conv2d:
#   config:
#     tile_type: "digital"
#     core_size: 16
#     rescale_dim: "vector"
#     x_quant_type: "int4"
#     weight_quant_type: "e4m3"

linear:
  config:
    tile_type: "digital"
    core_size: 64
    rescale_dim: "vector"
    x_quant_type: "e4m3"
    weight_quant_type: "e4m3"
```
Lines changed: 8 additions & 0 deletions

```rst
chop.passes.module.transform.pim
================================


pim_matmul_transform_pass
-----------------------------

.. autofunction:: chop.passes.module.transforms.pim.pim_matmul_transform_pass
```

docs/source/modules/chop/passes_module.rst

Lines changed: 5 additions & 1 deletion

```diff
@@ -35,6 +35,9 @@ Summary of Mase Module Transform Passes
    * - :py:meth:`~chop.passes.module.transforms.quantize.quantize_module_transform_pass`
      - `test_module_quantize <https://github.com/DeepWok/mase/blob/main/test/passes/module/transforms/quantize/test_quantize_module.py>`_
      - Apply quantization transformation to the given nn.Module
+   * - :py:meth:`~chop.passes.module.transforms.pim.pim_matmul_transform_pass`
+     - `test_cim_transform_module_roberta <https://github.com/DeepWok/mase/blob/main/test/passes/module/transforms/cim/test_cim_transform_module_roberta.py>`_
+     - Apply PIM transformation to the given nn.Module to simulate PIM hardware.
    * - :py:meth:`~chop.passes.module.transforms.onn.optical_transformer_module_transform_pass`
      - See :doc:`transform/onn`
      - Transform modules to Optical Neural Network (ONN) equivalents
@@ -44,4 +47,5 @@ Summary of Mase Module Transform Passes
    :caption: Full list of module-level transform passes

    module_transform/quantization
-   transform/onn
+   module_transform/pim
+   transform/onn
```

docs/source/modules/documentation/tutorials.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -59,6 +59,7 @@ Advanced Topics

    tutorials/advanced/tensorRT_quantization_tutorial
    tutorials/advanced/onnxrt_quantization_tutorial
+   tutorials/advanced/pim_transform_tutorial
    tutorials/advanced/cli


```

Lines changed: 107 additions & 0 deletions
# Advanced: PIM Simulation Tutorial

This tutorial demonstrates how to use the Mase framework to model and simulate a model's behaviour on PIM (processing-in-memory) devices.

In this tutorial, we focus on simulating the behaviour of PCM (phase-change memory) devices. For detailed information about the simulation parameters, please refer to [Hardware-aware training for large-scale and diverse deep learning inference workloads using in-memory computing-based accelerators](https://www.nature.com/articles/s41467-023-40770-4).

For a detailed explanation of the device simulation, see the main documentation. Here we only show the simulation and evaluation pipeline with the Mase framework.
## Section 1. Evaluation with golden model

In this section, we evaluate the baseline `RoBERTa` model on the `MNLI` (Multi-Genre Natural Language Inference) dataset. The MNLI dataset consists of pairs of sentences (a premise and a hypothesis), and the goal is to predict whether the premise entails, contradicts, or is neutral towards the hypothesis.

We use the `JeremiahZ/roberta-base-mnli` model, which is a RoBERTa-base model fine-tuned on MNLI.

```python
from transformers import RobertaForSequenceClassification, AutoTokenizer
from chop.dataset.nlp.text_entailment import TextEntailmentDatasetMNLI
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

def evaluate(model, dataloader, device):
    model.eval()
    model.to(device)
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device).squeeze(-1)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total

pretrained = "JeremiahZ/roberta-base-mnli"
tokenizer = AutoTokenizer.from_pretrained(pretrained)
model = RobertaForSequenceClassification.from_pretrained(pretrained)

# Load a small subset of MNLI validation set for quick evaluation
dataset = TextEntailmentDatasetMNLI(split="validation_matched", tokenizer=tokenizer, max_token_len=128, num_workers=4)
dataloader = DataLoader(dataset, batch_size=16)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
accuracy = evaluate(model, dataloader, device)
print(f"Golden Model Accuracy: {accuracy:.4f}")
```

**Output:**
```text
Evaluating: 100%|██████████| 614/614 [00:35<00:00, 17.13it/s]
Golden Model Accuracy: 0.8728
```
## Section 2. Evaluation with transformed model

Now we apply the `pim_matmul_transform_pass` to simulate the PIM hardware. We configure the transform to use PCM (Phase Change Memory) tiles with a core size of 256. We can enable various non-idealities like programming noise and read noise to see how they impact the model's performance.

### Configuration Details:
- `tile_type`: The type of PIM technology to simulate (e.g., 'pcm', 'sram', 'reram').
- `core_size`: The size of the crossbar array (e.g., 256x256).
- `num_bits`: The number of bits used for weights and activations.
- `programming_noise`: Simulates variability during the programming of PIM cells.
- `read_noise`: Simulates noise during the read-out process.
- `ir_drop`: Simulates voltage drops along the crossbar lines.
- `out_noise`: Simulates noise at the output of the crossbar.

```python
from chop.passes.module.transforms import pim_matmul_transform_pass
import copy

q_config = {
    "by": "type",
    "linear": {
        "config": {
            "tile_type": "pcm",
            "core_size": 256,
            "num_bits": 8,
            "programming_noise": True,
            "read_noise": True,
            "ir_drop": True,
            "out_noise": True,
        }
    },
}

# Apply the transform pass
transformed_model = copy.deepcopy(model)
qmodel, _ = pim_matmul_transform_pass(transformed_model, q_config)

q_accuracy = evaluate(qmodel, dataloader, device)
print(f"Transformed Model Accuracy (with PIM noise): {q_accuracy:.4f}")
```

**Output:**
```text
Evaluating: 100%|██████████| 614/614 [15:49<00:00, 1.55s/it]
Transformed Model Accuracy (with PIM noise): 0.3293
```

## Conclusion
In this tutorial, we demonstrated how to:
1. Load a pretrained RoBERTa model and evaluate it on the MNLI dataset.
2. Use `pim_matmul_transform_pass` to simulate hardware non-idealities for PIM devices.
3. Evaluate the impact of these non-idealities on model accuracy.
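As a possible next step, the same pipeline can isolate the contribution of each non-ideality. A hedged sketch, reusing `model`, `evaluate`, `dataloader`, and `device` from the sections above; expect runtimes similar to the transformed run above:

```python
import copy

from chop.passes.module.transforms import pim_matmul_transform_pass

# Enable one PCM non-ideality at a time and measure its effect on accuracy.
for flag in ["programming_noise", "read_noise", "ir_drop", "out_noise"]:
    config = {
        "by": "type",
        "linear": {
            "config": {
                "tile_type": "pcm",
                "core_size": 256,
                "num_bits": 8,
                "programming_noise": False,
                "read_noise": False,
                "ir_drop": False,
                "out_noise": False,
            }
        },
    }
    config["linear"]["config"][flag] = True
    m, _ = pim_matmul_transform_pass(copy.deepcopy(model), config)
    print(flag, evaluate(m, dataloader, device))
```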
