Skip to content

Commit dfe48d0

Browse files
Modify CCT Test
1 parent 531fa57 commit dfe48d0

File tree

1 file changed

+77
-4
lines changed

1 file changed

+77
-4
lines changed

Tests/TestCCT.py

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#
55
# Federico Brancasi <[email protected]>
66

7+
from typing import Tuple
8+
79
import brevitas.nn as qnn
810
import pytest
911
import torch
@@ -18,9 +20,61 @@
1820
Uint8ActPerTensorFloat,
1921
)
2022

23+
from DeepQuant.Transforms.Executor import TransformationExecutor
24+
from DeepQuant.Transforms.Transformations import LinearTransformation, MHATransformation
25+
from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc
26+
from DeepQuant.Utils.CustomTracer import QuantTracer, customBrevitasTrace
27+
from DeepQuant.Utils.GraphPrinter import GraphModulePrinter
2128
from Tests.Models.CCT import cct_2_3x2_32
2229

2330

31+
def injectCustomForwards(
32+
model: nn.Module,
33+
exampleInput: torch.Tensor,
34+
referenceOutput: torch.Tensor,
35+
debug: bool = False,
36+
checkEquivalence: bool = False,
37+
) -> Tuple[nn.Module, torch.Tensor]:
38+
"""Custom inject function for CCT that excludes ActivationTransformation."""
39+
printer = GraphModulePrinter()
40+
41+
tracer = QuantTracer(debug=debug)
42+
43+
transformations = [
44+
MHATransformation(),
45+
LinearTransformation(),
46+
# ActivationTransformation(), # FBRANCASI: Commented out for CCT compatibility
47+
]
48+
49+
executor = TransformationExecutor(transformations, debug=debug, tracer=tracer)
50+
transformedModel = executor.execute(model, exampleInput)
51+
52+
fxModel = customBrevitasTrace(
53+
root=transformedModel,
54+
tracer=tracer,
55+
)
56+
fxModel.recompile()
57+
58+
with torch.no_grad():
59+
output = fxModel(exampleInput)
60+
61+
if checkEquivalence:
62+
if torch.allclose(referenceOutput, output, atol=1e-5):
63+
if debug:
64+
print(cc.success("Injection of New Modules: output is consistent"))
65+
else:
66+
raise RuntimeError(
67+
cc.error("Injection of New Modules changed the output significantly")
68+
)
69+
70+
if debug:
71+
print(cc.header("2. Network after Injection of New Modules"))
72+
printer.printTabular(fxModel)
73+
print()
74+
75+
return fxModel, output
76+
77+
2478
def prepareCCT(model) -> nn.Module:
2579
"""
2680
Prepare a quantized CCT model for testing with export support.
@@ -84,8 +138,6 @@ def prepareCCT(model) -> nn.Module:
84138

85139
quant_name = f"{node.name}_reshape_fix"
86140
model.add_module(quant_name, quant_identity)
87-
# mark this QuantIdentity as “reshape fix”
88-
quant_identity._is_reshape_fix = True
89141

90142
with model.graph.inserting_after(node):
91143
quant_node = model.graph.call_module(quant_name, args=(node,))
@@ -181,6 +233,27 @@ def deepQuantTestCCT():
181233
print(f"Output shape: {output.shape}")
182234
print(f"Output range: [{output.min().item():.3f}, {output.max().item():.3f}]")
183235

184-
from DeepQuant import brevitasToTrueQuant
236+
# FBRANCASI: Override the injectCustomForwards function in the module before DeepQuant.Export imports it
237+
import DeepQuant.Pipeline.Injection as injection_module
238+
239+
# FBRANCASI: Store original function
240+
original_inject = injection_module.injectCustomForwards
241+
242+
# FBRANCASI: Override with our custom function
243+
injection_module.injectCustomForwards = injectCustomForwards
244+
245+
# FBRANCASI: Force reload of Export module to pick up the override
246+
import importlib
247+
248+
import DeepQuant.Export
249+
250+
importlib.reload(DeepQuant.Export)
251+
252+
try:
253+
from DeepQuant.Export import brevitasToTrueQuant
185254

186-
brevitasToTrueQuant(quantizedModel, sampleInput, debug=True)
255+
brevitasToTrueQuant(quantizedModel, sampleInput, debug=True)
256+
finally:
257+
# FBRANCASI: Restore original function and reload Export module again
258+
injection_module.injectCustomForwards = original_inject
259+
importlib.reload(DeepQuant.Export)

0 commit comments

Comments
 (0)