Skip to content

Commit 29d7c0c

Browse files
Modify CCT Test Pretrained
1 parent dfe48d0 commit 29d7c0c

File tree

1 file changed

+79
-4
lines changed

1 file changed

+79
-4
lines changed

Tests/TestCCTPretrained.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#
55
# Federico Brancasi <[email protected]>
66

7+
from typing import Tuple
8+
79
import brevitas.nn as qnn
810
import pytest
911
import torch
@@ -23,10 +25,61 @@
2325
from torch.utils.data import DataLoader, Subset
2426
from tqdm import tqdm
2527

26-
from DeepQuant import brevitasToTrueQuant
28+
from DeepQuant.Transforms.Executor import TransformationExecutor
29+
from DeepQuant.Transforms.Transformations import LinearTransformation, MHATransformation
30+
from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc
31+
from DeepQuant.Utils.CustomTracer import QuantTracer, customBrevitasTrace
32+
from DeepQuant.Utils.GraphPrinter import GraphModulePrinter
2733
from Tests.Models.CCT import cct_2_3x2_32
2834

2935

36+
def injectCustomForwards(
37+
model: nn.Module,
38+
exampleInput: torch.Tensor,
39+
referenceOutput: torch.Tensor,
40+
debug: bool = False,
41+
checkEquivalence: bool = False,
42+
) -> Tuple[nn.Module, torch.Tensor]:
43+
"""Custom inject function for CCT that excludes ActivationTransformation."""
44+
printer = GraphModulePrinter()
45+
46+
tracer = QuantTracer(debug=debug)
47+
48+
transformations = [
49+
MHATransformation(),
50+
LinearTransformation(),
51+
# ActivationTransformation(), # FBRANCASI: Commented out for CCT compatibility
52+
]
53+
54+
executor = TransformationExecutor(transformations, debug=debug, tracer=tracer)
55+
transformedModel = executor.execute(model, exampleInput)
56+
57+
fxModel = customBrevitasTrace(
58+
root=transformedModel,
59+
tracer=tracer,
60+
)
61+
fxModel.recompile()
62+
63+
with torch.no_grad():
64+
output = fxModel(exampleInput)
65+
66+
if checkEquivalence:
67+
if torch.allclose(referenceOutput, output, atol=1e-5):
68+
if debug:
69+
print(cc.success("Injection of New Modules: output is consistent"))
70+
else:
71+
raise RuntimeError(
72+
cc.error("Injection of New Modules changed the output significantly")
73+
)
74+
75+
if debug:
76+
print(cc.header("2. Network after Injection of New Modules"))
77+
printer.printTabular(fxModel)
78+
print()
79+
80+
return fxModel, output
81+
82+
3083
def evaluateModel(model, dataLoader, evalDevice, name="Model"):
3184
model.eval()
3285
correct = 0
@@ -133,8 +186,6 @@ def prepareFQCCT(model) -> nn.Module:
133186

134187
quant_name = f"{node.name}_reshape_fix"
135188
model.add_module(quant_name, quant_identity)
136-
# mark this QuantIdentity as “reshape fix”
137-
quant_identity._is_reshape_fix = True
138189

139190
with model.graph.inserting_after(node):
140191
quant_node = model.graph.call_module(quant_name, args=(node,))
@@ -265,7 +316,31 @@ def deepQuantTestCCT():
265316
FQAccuracy = evaluateModel(FQModel, valLoader, device, "FQ CCT-2")
266317

267318
sampleInput = torch.randn(1, 3, 32, 32).to("cpu")
268-
TQModel = brevitasToTrueQuant(FQModel, sampleInput, debug=True)
319+
320+
# FBRANCASI: Override the injectCustomForwards function in the module before DeepQuant.Export imports it
321+
import DeepQuant.Pipeline.Injection as injection_module
322+
323+
# FBRANCASI: Store original function
324+
original_inject = injection_module.injectCustomForwards
325+
326+
# FBRANCASI: Override with our custom function
327+
injection_module.injectCustomForwards = injectCustomForwards
328+
329+
# FBRANCASI: Force reload of Export module to pick up the override
330+
import importlib
331+
332+
import DeepQuant.Export
333+
334+
importlib.reload(DeepQuant.Export)
335+
336+
try:
337+
from DeepQuant.Export import brevitasToTrueQuant
338+
339+
TQModel = brevitasToTrueQuant(FQModel, sampleInput, debug=True)
340+
finally:
341+
# FBRANCASI: Restore original function and reload Export module again
342+
injection_module.injectCustomForwards = original_inject
343+
importlib.reload(DeepQuant.Export)
269344

270345
numParameters = sum(p.numel() for p in TQModel.parameters())
271346
print(f"Number of parameters: {numParameters:,}")

0 commit comments

Comments
 (0)