33# SPDX-License-Identifier: Apache-2.0
44#
55# Federico Brancasi <[email protected] > 6+
67import brevitas .nn as qnn
78import pytest
89import torch
910import torch .nn as nn
11+ from brevitas .fx .brevitas_tracer import symbolic_trace
1012from brevitas .graph .quantize import preprocess_for_quantize , quantize
13+ from brevitas .graph .utils import replace_all_uses_except
1114from brevitas .quant import (
1215 Int8ActPerTensorFloat ,
1316 Int8WeightPerTensorFloat ,
@@ -22,98 +25,76 @@ def prepareCCT(model) -> nn.Module:
2225 """
2326 Prepare a quantized CCT model for testing with export support.
2427 """
25- import operator
26-
27- from brevitas .fx .brevitas_tracer import symbolic_trace
28- from brevitas .graph .utils import replace_all_uses_except
2928
30- # First trace the model
3129 if not hasattr (model , "graph" ):
3230 model = symbolic_trace (model )
3331
3432 print ("=== FIXING QUANTIZATION ISSUES ===" )
3533
36- # Collect all modifications first, then apply them
3734 transpose_fixes = []
3835 qkv_fixes = []
3936
40- # Fix 1: Find transpose -> add patterns
37+ # FBRANCASI: Fix 1, Find transpose -> add patterns
4138 for node in model .graph .nodes :
4239 if node .op == "call_method" and node .target == "transpose" :
4340 for user in node .users :
4441 if (
4542 "add" in user .name
46- or user .target
47- in [
48- torch .add ,
49- operator .add ,
50- operator .iadd ,
51- operator .__add__ ,
52- operator .__iadd__ ,
53- ]
43+ or user .target in [torch .add ]
5444 or (user .op == "call_method" and user .target in ["add" , "add_" ])
5545 ):
5646 transpose_fixes .append ((node , user ))
5747 break
5848
59- # Fix 2: Find QKV -> reshape patterns
49+ # FBRANCASI: Fix 2, Find QKV -> reshape patterns
6050 for node in model .graph .nodes :
6151 if node .op == "call_module" and "qkv" in node .target :
6252 for user in node .users :
6353 if user .op == "call_method" and user .target == "reshape" :
6454 qkv_fixes .append ((node , user ))
6555 break
6656
67- # Apply transpose fixes
57+ # FBRANCASI: Apply transpose fixes
6858 print (f"\n Applying { len (transpose_fixes )} transpose fixes..." )
6959 for node , user in transpose_fixes :
7060 print (f" Fixing: { node .name } -> { user .name } " )
7161
72- # Create a QuantIdentity
7362 quant_identity = qnn .QuantIdentity (
7463 act_quant = Int8ActPerTensorFloat , return_quant_tensor = True
7564 )
7665
77- # Add to model
7866 quant_name = f"{ node .name } _quant_fix"
7967 model .add_module (quant_name , quant_identity )
8068
81- # Insert in the graph after transpose
8269 with model .graph .inserting_after (node ):
8370 quant_node = model .graph .call_module (quant_name , args = (node ,))
8471
8572 # Replace uses
8673 replace_all_uses_except (node , quant_node , [quant_node ])
8774
88- # Apply QKV fixes
75+ # FBRANCASI: Apply QKV fixes
8976 print (f"\n Applying { len (qkv_fixes )} QKV fixes..." )
9077 for node , reshape_user in qkv_fixes :
9178 print (f" Fixing: { node .name } -> { reshape_user .name } " )
9279
93- # Create a QuantIdentity to handle the tensor properly
9480 quant_identity = qnn .QuantIdentity (
9581 act_quant = Int8ActPerTensorFloat ,
96- return_quant_tensor = False , # Important : return regular tensor for reshape
82+ return_quant_tensor = False , # FBRANCASI : return regular tensor for reshape
9783 )
9884
99- # Add to model
10085 quant_name = f"{ node .name } _reshape_fix"
10186 model .add_module (quant_name , quant_identity )
10287
103- # Insert in the graph between qkv and reshape
10488 with model .graph .inserting_after (node ):
10589 quant_node = model .graph .call_module (quant_name , args = (node ,))
10690
107- # Update reshape to use the quant_node output
10891 reshape_user .update_arg (0 , quant_node )
10992
110- # Recompile graph only once after all modifications
11193 model .recompile ()
11294 model .graph .lint ()
11395
114- print ("\n === Graph modifications complete ===" )
96+ print ("\n === GRAPH MODIFICATION COMPLETE ===" )
11597
116- # Define quantization mappings
11798 computeLayerMap = {
11899 nn .Conv2d : (
119100 qnn .QuantConv2d ,
@@ -128,21 +109,6 @@ def prepareCCT(model) -> nn.Module:
128109 "weight_bit_width" : 4 ,
129110 },
130111 ),
131- nn .MultiheadAttention : (
132- qnn .QuantMultiheadAttention ,
133- {
134- "in_proj_weight_quant" : Int8WeightPerTensorFloat ,
135- "in_proj_bias_quant" : Int32Bias ,
136- "attn_output_weights_quant" : Uint8ActPerTensorFloat ,
137- "q_scaled_quant" : Int8ActPerTensorFloat ,
138- "k_transposed_quant" : Int8ActPerTensorFloat ,
139- "v_quant" : Int8ActPerTensorFloat ,
140- "out_proj_input_quant" : Int8ActPerTensorFloat ,
141- "out_proj_weight_quant" : Int8WeightPerTensorFloat ,
142- "out_proj_bias_quant" : Int32Bias ,
143- "return_quant_tensor" : True ,
144- },
145- ),
146112 nn .Linear : (
147113 qnn .QuantLinear ,
148114 {
@@ -195,15 +161,13 @@ def prepareCCT(model) -> nn.Module:
195161 ),
196162 }
197163
198- # Preprocess model
199164 model = preprocess_for_quantize (
200165 model ,
201166 equalize_iters = 10 ,
202167 equalize_scale_computation = "range" ,
203- trace_model = False , # Already traced
168+ trace_model = False , # FBRANCASI: Already traced
204169 )
205170
206- # Quantize model
207171 quantizedModel = quantize (
208172 graph_model = model ,
209173 compute_layer_map = computeLayerMap ,
@@ -219,15 +183,14 @@ def deepQuantTestCCT():
219183 torch .manual_seed (42 )
220184 sampleInput = torch .randn (1 , 3 , 32 , 32 )
221185
222- model = cct_2_3x2_32 () # 2 encoder layers, kernel dim 3, 2 convs, 32x32
186+ model = cct_2_3x2_32 () # FBRANCASI: 2 encoder layers, kernel dim 3, 2 convs, 32x32
223187 model .eval ()
224188
225189 print (model )
226190
227191 quantizedModel = prepareCCT (model )
228192
229- # Test the quantized model
230- print (f"\n Testing with input shape: { sampleInput .shape } " )
193+ print (f"\n Testing the Quantized Model with input shape: { sampleInput .shape } " )
231194 with torch .no_grad ():
232195 output = quantizedModel (sampleInput )
233196 print (f"Output shape: { output .shape } " )
0 commit comments