Skip to content

Commit 01a60c5

Browse files
Refactor CCT Test
1 parent e2bc1b4 commit 01a60c5

File tree

1 file changed

+13
-50
lines changed

1 file changed

+13
-50
lines changed

Tests/TestCCT.py

Lines changed: 13 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,14 @@
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Federico Brancasi <[email protected]>
6+
67
import brevitas.nn as qnn
78
import pytest
89
import torch
910
import torch.nn as nn
11+
from brevitas.fx.brevitas_tracer import symbolic_trace
1012
from brevitas.graph.quantize import preprocess_for_quantize, quantize
13+
from brevitas.graph.utils import replace_all_uses_except
1114
from brevitas.quant import (
1215
Int8ActPerTensorFloat,
1316
Int8WeightPerTensorFloat,
@@ -22,98 +25,76 @@ def prepareCCT(model) -> nn.Module:
2225
"""
2326
Prepare a quantized CCT model for testing with export support.
2427
"""
25-
import operator
26-
27-
from brevitas.fx.brevitas_tracer import symbolic_trace
28-
from brevitas.graph.utils import replace_all_uses_except
2928

30-
# First trace the model
3129
if not hasattr(model, "graph"):
3230
model = symbolic_trace(model)
3331

3432
print("=== FIXING QUANTIZATION ISSUES ===")
3533

36-
# Collect all modifications first, then apply them
3734
transpose_fixes = []
3835
qkv_fixes = []
3936

40-
# Fix 1: Find transpose -> add patterns
37+
# FBRANCASI: Fix 1, Find transpose -> add patterns
4138
for node in model.graph.nodes:
4239
if node.op == "call_method" and node.target == "transpose":
4340
for user in node.users:
4441
if (
4542
"add" in user.name
46-
or user.target
47-
in [
48-
torch.add,
49-
operator.add,
50-
operator.iadd,
51-
operator.__add__,
52-
operator.__iadd__,
53-
]
43+
or user.target in [torch.add]
5444
or (user.op == "call_method" and user.target in ["add", "add_"])
5545
):
5646
transpose_fixes.append((node, user))
5747
break
5848

59-
# Fix 2: Find QKV -> reshape patterns
49+
# FBRANCASI: Fix 2, Find QKV -> reshape patterns
6050
for node in model.graph.nodes:
6151
if node.op == "call_module" and "qkv" in node.target:
6252
for user in node.users:
6353
if user.op == "call_method" and user.target == "reshape":
6454
qkv_fixes.append((node, user))
6555
break
6656

67-
# Apply transpose fixes
57+
# FBRANCASI: Apply transpose fixes
6858
print(f"\nApplying {len(transpose_fixes)} transpose fixes...")
6959
for node, user in transpose_fixes:
7060
print(f" Fixing: {node.name} -> {user.name}")
7161

72-
# Create a QuantIdentity
7362
quant_identity = qnn.QuantIdentity(
7463
act_quant=Int8ActPerTensorFloat, return_quant_tensor=True
7564
)
7665

77-
# Add to model
7866
quant_name = f"{node.name}_quant_fix"
7967
model.add_module(quant_name, quant_identity)
8068

81-
# Insert in the graph after transpose
8269
with model.graph.inserting_after(node):
8370
quant_node = model.graph.call_module(quant_name, args=(node,))
8471

8572
# Replace uses
8673
replace_all_uses_except(node, quant_node, [quant_node])
8774

88-
# Apply QKV fixes
75+
# FBRANCASI: Apply QKV fixes
8976
print(f"\nApplying {len(qkv_fixes)} QKV fixes...")
9077
for node, reshape_user in qkv_fixes:
9178
print(f" Fixing: {node.name} -> {reshape_user.name}")
9279

93-
# Create a QuantIdentity to handle the tensor properly
9480
quant_identity = qnn.QuantIdentity(
9581
act_quant=Int8ActPerTensorFloat,
96-
return_quant_tensor=False, # Important: return regular tensor for reshape
82+
return_quant_tensor=False, # FBRANCASI: return regular tensor for reshape
9783
)
9884

99-
# Add to model
10085
quant_name = f"{node.name}_reshape_fix"
10186
model.add_module(quant_name, quant_identity)
10287

103-
# Insert in the graph between qkv and reshape
10488
with model.graph.inserting_after(node):
10589
quant_node = model.graph.call_module(quant_name, args=(node,))
10690

107-
# Update reshape to use the quant_node output
10891
reshape_user.update_arg(0, quant_node)
10992

110-
# Recompile graph only once after all modifications
11193
model.recompile()
11294
model.graph.lint()
11395

114-
print("\n=== Graph modifications complete ===")
96+
print("\n=== GRAPH MODIFICATION COMPLETE ===")
11597

116-
# Define quantization mappings
11798
computeLayerMap = {
11899
nn.Conv2d: (
119100
qnn.QuantConv2d,
@@ -128,21 +109,6 @@ def prepareCCT(model) -> nn.Module:
128109
"weight_bit_width": 4,
129110
},
130111
),
131-
nn.MultiheadAttention: (
132-
qnn.QuantMultiheadAttention,
133-
{
134-
"in_proj_weight_quant": Int8WeightPerTensorFloat,
135-
"in_proj_bias_quant": Int32Bias,
136-
"attn_output_weights_quant": Uint8ActPerTensorFloat,
137-
"q_scaled_quant": Int8ActPerTensorFloat,
138-
"k_transposed_quant": Int8ActPerTensorFloat,
139-
"v_quant": Int8ActPerTensorFloat,
140-
"out_proj_input_quant": Int8ActPerTensorFloat,
141-
"out_proj_weight_quant": Int8WeightPerTensorFloat,
142-
"out_proj_bias_quant": Int32Bias,
143-
"return_quant_tensor": True,
144-
},
145-
),
146112
nn.Linear: (
147113
qnn.QuantLinear,
148114
{
@@ -195,15 +161,13 @@ def prepareCCT(model) -> nn.Module:
195161
),
196162
}
197163

198-
# Preprocess model
199164
model = preprocess_for_quantize(
200165
model,
201166
equalize_iters=10,
202167
equalize_scale_computation="range",
203-
trace_model=False, # Already traced
168+
trace_model=False, # FBRANCASI: Already traced
204169
)
205170

206-
# Quantize model
207171
quantizedModel = quantize(
208172
graph_model=model,
209173
compute_layer_map=computeLayerMap,
@@ -219,15 +183,14 @@ def deepQuantTestCCT():
219183
torch.manual_seed(42)
220184
sampleInput = torch.randn(1, 3, 32, 32)
221185

222-
model = cct_2_3x2_32() # 2 encoder layers, kernel dim 3, 2 convs, 32x32
186+
model = cct_2_3x2_32() # FBRANCASI: 2 encoder layers, kernel dim 3, 2 convs, 32x32
223187
model.eval()
224188

225189
print(model)
226190

227191
quantizedModel = prepareCCT(model)
228192

229-
# Test the quantized model
230-
print(f"\nTesting with input shape: {sampleInput.shape}")
193+
print(f"\nTesting the Quantized Model with input shape: {sampleInput.shape}")
231194
with torch.no_grad():
232195
output = quantizedModel(sampleInput)
233196
print(f"Output shape: {output.shape}")

0 commit comments

Comments
 (0)