|
4 | 4 | # |
5 | 5 | # Federico Brancasi <[email protected]> |
6 | 6 |
|
| 7 | +from typing import Tuple |
| 8 | + |
7 | 9 | import brevitas.nn as qnn |
8 | 10 | import pytest |
9 | 11 | import torch |
|
23 | 25 | from torch.utils.data import DataLoader, Subset |
24 | 26 | from tqdm import tqdm |
25 | 27 |
|
26 | | -from DeepQuant import brevitasToTrueQuant |
| 28 | +from DeepQuant.Transforms.Executor import TransformationExecutor |
| 29 | +from DeepQuant.Transforms.Transformations import LinearTransformation, MHATransformation |
| 30 | +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc |
| 31 | +from DeepQuant.Utils.CustomTracer import QuantTracer, customBrevitasTrace |
| 32 | +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter |
27 | 33 | from Tests.Models.CCT import cct_2_3x2_32 |
28 | 34 |
|
29 | 35 |
|
| 36 | +def injectCustomForwards( |
| 37 | + model: nn.Module, |
| 38 | + exampleInput: torch.Tensor, |
| 39 | + referenceOutput: torch.Tensor, |
| 40 | + debug: bool = False, |
| 41 | + checkEquivalence: bool = False, |
| 42 | +) -> Tuple[nn.Module, torch.Tensor]: |
| 43 | + """Custom inject function for CCT that excludes ActivationTransformation.""" |
| 44 | + printer = GraphModulePrinter() |
| 45 | + |
| 46 | + tracer = QuantTracer(debug=debug) |
| 47 | + |
| 48 | + transformations = [ |
| 49 | + MHATransformation(), |
| 50 | + LinearTransformation(), |
| 51 | + # ActivationTransformation(), # FBRANCASI: Commented out for CCT compatibility |
| 52 | + ] |
| 53 | + |
| 54 | + executor = TransformationExecutor(transformations, debug=debug, tracer=tracer) |
| 55 | + transformedModel = executor.execute(model, exampleInput) |
| 56 | + |
| 57 | + fxModel = customBrevitasTrace( |
| 58 | + root=transformedModel, |
| 59 | + tracer=tracer, |
| 60 | + ) |
| 61 | + fxModel.recompile() |
| 62 | + |
| 63 | + with torch.no_grad(): |
| 64 | + output = fxModel(exampleInput) |
| 65 | + |
| 66 | + if checkEquivalence: |
| 67 | + if torch.allclose(referenceOutput, output, atol=1e-5): |
| 68 | + if debug: |
| 69 | + print(cc.success("Injection of New Modules: output is consistent")) |
| 70 | + else: |
| 71 | + raise RuntimeError( |
| 72 | + cc.error("Injection of New Modules changed the output significantly") |
| 73 | + ) |
| 74 | + |
| 75 | + if debug: |
| 76 | + print(cc.header("2. Network after Injection of New Modules")) |
| 77 | + printer.printTabular(fxModel) |
| 78 | + print() |
| 79 | + |
| 80 | + return fxModel, output |
| 81 | + |
| 82 | + |
30 | 83 | def evaluateModel(model, dataLoader, evalDevice, name="Model"): |
31 | 84 | model.eval() |
32 | 85 | correct = 0 |
@@ -133,8 +186,6 @@ def prepareFQCCT(model) -> nn.Module: |
133 | 186 |
|
134 | 187 | quant_name = f"{node.name}_reshape_fix" |
135 | 188 | model.add_module(quant_name, quant_identity) |
136 | | - # mark this QuantIdentity as “reshape fix” |
137 | | - quant_identity._is_reshape_fix = True |
138 | 189 |
|
139 | 190 | with model.graph.inserting_after(node): |
140 | 191 | quant_node = model.graph.call_module(quant_name, args=(node,)) |
@@ -265,7 +316,31 @@ def deepQuantTestCCT(): |
265 | 316 | FQAccuracy = evaluateModel(FQModel, valLoader, device, "FQ CCT-2") |
266 | 317 |
|
267 | 318 | sampleInput = torch.randn(1, 3, 32, 32).to("cpu") |
268 | | - TQModel = brevitasToTrueQuant(FQModel, sampleInput, debug=True) |
| 319 | + |
| 320 | + # FBRANCASI: Override the injectCustomForwards function in the module before DeepQuant.Export imports it |
| 321 | + import DeepQuant.Pipeline.Injection as injection_module |
| 322 | + |
| 323 | + # FBRANCASI: Store original function |
| 324 | + original_inject = injection_module.injectCustomForwards |
| 325 | + |
| 326 | + # FBRANCASI: Override with our custom function |
| 327 | + injection_module.injectCustomForwards = injectCustomForwards |
| 328 | + |
| 329 | + # FBRANCASI: Force reload of Export module to pick up the override |
| 330 | + import importlib |
| 331 | + |
| 332 | + import DeepQuant.Export |
| 333 | + |
| 334 | + importlib.reload(DeepQuant.Export) |
| 335 | + |
| 336 | + try: |
| 337 | + from DeepQuant.Export import brevitasToTrueQuant |
| 338 | + |
| 339 | + TQModel = brevitasToTrueQuant(FQModel, sampleInput, debug=True) |
| 340 | + finally: |
| 341 | + # FBRANCASI: Restore original function and reload Export module again |
| 342 | + injection_module.injectCustomForwards = original_inject |
| 343 | + importlib.reload(DeepQuant.Export) |
269 | 344 |
|
270 | 345 | numParameters = sum(p.numel() for p in TQModel.parameters()) |
271 | 346 | print(f"Number of parameters: {numParameters:,}") |
|
0 commit comments