Skip to content

Commit a06cb52

Browse files
committed
add basic readme
1 parent e57992b commit a06cb52

File tree

1 file changed

+129
-0
lines changed
  • src/pruna/algorithms/quantization/backends/ganq

1 file changed

+129
-0
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#### Guide to using GANQ quantization
2+
3+
**Quantize a model**
4+
5+
```python
6+
import torch
7+
from transformers import AutoModelForCausalLM, AutoTokenizer
8+
9+
from pruna.config.smash_config import SmashConfig
10+
from pruna.data.pruna_datamodule import PrunaDataModule
11+
12+
13+
import torch
14+
from transformers import AutoModelForCausalLM
15+
16+
import torch
17+
from pruna.algorithms.quantization.ganq import GANQQuantizer
18+
19+
# -------------------------------------------------------------------------
20+
# 1. Load model and tokenizer
21+
# -------------------------------------------------------------------------
22+
model_name = "HuggingFaceTB/SmolLM2-135M"
23+
tokenizer = AutoTokenizer.from_pretrained(model_name)
24+
if tokenizer.pad_token is None:
25+
tokenizer.pad_token = tokenizer.eos_token
26+
model = AutoModelForCausalLM.from_pretrained(
27+
model_name, torch_dtype=torch.float16, device_map="auto"
28+
)
29+
model.eval()
30+
31+
# -------------------------------------------------------------------------
32+
# 2. Build SmashConfig for Pruna Quantizer
33+
# -------------------------------------------------------------------------
34+
smash_config = SmashConfig(
35+
batch_size=4,
36+
device="cuda" if torch.cuda.is_available() else "cpu",
37+
cache_dir_prefix="./cache_ganq",
38+
)
39+
40+
# Add tokenizer
41+
smash_config.add_tokenizer(tokenizer)
42+
43+
# Use Pruna's built-in WikiText dataset (handles train/val/test splits automatically)
44+
data_module = PrunaDataModule.from_string(
45+
"WikiText",
46+
tokenizer=tokenizer,
47+
collate_fn_args=dict(max_seq_len=256),
48+
)
49+
data_module.limit_datasets(32) # Limit to 32 examples per split for quick testing
50+
smash_config.add_data(data_module)
51+
52+
# Configure quantizer parameters
53+
smash_config.load_dict(
54+
{
55+
"quantizer": "ganq",
56+
"ganq_weight_bits": 4,
57+
"ganq_max_epoch": 10,
58+
"ganq_pre_process": True,
59+
}
60+
)
61+
62+
# -------------------------------------------------------------------------
63+
# 4. Run Quantization
64+
# -------------------------------------------------------------------------
65+
quantizer = GANQQuantizer()
66+
67+
quantized_model = quantizer._apply(model, smash_config)
68+
69+
# -------------------------------------------------------------------------
70+
# 5. Save the quantized model
71+
# -------------------------------------------------------------------------
72+
quantized_model.save_pretrained("./ganq_quantized_smollm")
73+
tokenizer.save_pretrained("./ganq_quantized_smollm")
74+
75+
print("✅ GANQ quantization complete and model saved at ./ganq_quantized_smollm")
76+
77+
78+
def model_size_in_mb(model):
79+
param_size = 0
80+
for param in model.parameters():
81+
param_size += param.nelement() * param.element_size()
82+
buffer_size = 0
83+
for buffer in model.buffers():
84+
buffer_size += buffer.nelement() * buffer.element_size()
85+
size_all_mb = (param_size + buffer_size) / 1024**2
86+
return size_all_mb
87+
88+
89+
original_size = model_size_in_mb(model)
90+
quantized_size = model_size_in_mb(quantized_model)
91+
print(f"Original model size: {original_size:.2f} MB")
92+
print(f"Quantized model size: {quantized_size:.2f} MB")
93+
94+
```
95+
96+
97+
**Verify that quantization worked**
98+
99+
The logic here is that since GANQ uses a codebook of size (m, L) for a weight matrix of size (m, n), where L is 2^k (k = number of bits), each row in the weight matrix W should only contain values from the corresponding row in the codebook, where selection is driven by the one-hot matrix S. So the number of unique values in each row of W should be exactly L.
100+
101+
```python
102+
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the original (fp16) model for comparison.
model_name = "HuggingFaceTB/SmolLM2-135M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
model.eval()

# Load the GANQ-quantized model saved by the quantization script.
model_q = AutoModelForCausalLM.from_pretrained(
    "ganq_quantized_smollm"
)


def verify_unique_entries_in_row(layer, row_idx=0):
    """Print how many distinct values one row of the layer's q_proj weight holds.

    After k-bit GANQ quantization each row is drawn from a per-row codebook
    of 2**k entries, so a quantized row should show exactly 2**k unique
    values, while an unquantized row shows far more.
    """
    Wq = layer.self_attn.q_proj.weight.data
    unique_entries = torch.unique(Wq[row_idx])
    print(f"Number of unique entries in row {row_idx}: {unique_entries.numel()}")


verify_unique_entries_in_row(model_q.model.layers[1], row_idx=1)
verify_unique_entries_in_row(model.model.layers[1], row_idx=1)

# In my experiments, it gave this:
# Number of unique entries in row 1: 16 (since I used 4-bit quantization)
# Number of unique entries in row 1: 471
```

0 commit comments

Comments
 (0)