|
4 | 4 | # |
5 | 5 | # Federico Brancasi <[email protected]> |
6 | 6 |
|
| 7 | +from typing import Tuple |
| 8 | + |
7 | 9 | import brevitas.nn as qnn |
8 | 10 | import pytest |
9 | 11 | import torch |
|
18 | 20 | Uint8ActPerTensorFloat, |
19 | 21 | ) |
20 | 22 |
|
| 23 | +from DeepQuant.Transforms.Executor import TransformationExecutor |
| 24 | +from DeepQuant.Transforms.Transformations import LinearTransformation, MHATransformation |
| 25 | +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc |
| 26 | +from DeepQuant.Utils.CustomTracer import QuantTracer, customBrevitasTrace |
| 27 | +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter |
21 | 28 | from Tests.Models.CCT import cct_2_3x2_32 |
22 | 29 |
|
23 | 30 |
|
| 31 | +def injectCustomForwards( |
| 32 | + model: nn.Module, |
| 33 | + exampleInput: torch.Tensor, |
| 34 | + referenceOutput: torch.Tensor, |
| 35 | + debug: bool = False, |
| 36 | + checkEquivalence: bool = False, |
| 37 | +) -> Tuple[nn.Module, torch.Tensor]: |
| 38 | + """Custom inject function for CCT that excludes ActivationTransformation.""" |
| 39 | + printer = GraphModulePrinter() |
| 40 | + |
| 41 | + tracer = QuantTracer(debug=debug) |
| 42 | + |
| 43 | + transformations = [ |
| 44 | + MHATransformation(), |
| 45 | + LinearTransformation(), |
| 46 | + # ActivationTransformation(), # FBRANCASI: Commented out for CCT compatibility |
| 47 | + ] |
| 48 | + |
| 49 | + executor = TransformationExecutor(transformations, debug=debug, tracer=tracer) |
| 50 | + transformedModel = executor.execute(model, exampleInput) |
| 51 | + |
| 52 | + fxModel = customBrevitasTrace( |
| 53 | + root=transformedModel, |
| 54 | + tracer=tracer, |
| 55 | + ) |
| 56 | + fxModel.recompile() |
| 57 | + |
| 58 | + with torch.no_grad(): |
| 59 | + output = fxModel(exampleInput) |
| 60 | + |
| 61 | + if checkEquivalence: |
| 62 | + if torch.allclose(referenceOutput, output, atol=1e-5): |
| 63 | + if debug: |
| 64 | + print(cc.success("Injection of New Modules: output is consistent")) |
| 65 | + else: |
| 66 | + raise RuntimeError( |
| 67 | + cc.error("Injection of New Modules changed the output significantly") |
| 68 | + ) |
| 69 | + |
| 70 | + if debug: |
| 71 | + print(cc.header("2. Network after Injection of New Modules")) |
| 72 | + printer.printTabular(fxModel) |
| 73 | + print() |
| 74 | + |
| 75 | + return fxModel, output |
| 76 | + |
| 77 | + |
24 | 78 | def prepareCCT(model) -> nn.Module: |
25 | 79 | """ |
26 | 80 | Prepare a quantized CCT model for testing with export support. |
@@ -84,8 +138,6 @@ def prepareCCT(model) -> nn.Module: |
84 | 138 |
|
85 | 139 | quant_name = f"{node.name}_reshape_fix" |
86 | 140 | model.add_module(quant_name, quant_identity) |
87 | | - # mark this QuantIdentity as “reshape fix” |
88 | | - quant_identity._is_reshape_fix = True |
89 | 141 |
|
90 | 142 | with model.graph.inserting_after(node): |
91 | 143 | quant_node = model.graph.call_module(quant_name, args=(node,)) |
@@ -181,6 +233,27 @@ def deepQuantTestCCT(): |
181 | 233 | print(f"Output shape: {output.shape}") |
182 | 234 | print(f"Output range: [{output.min().item():.3f}, {output.max().item():.3f}]") |
183 | 235 |
|
184 | | - from DeepQuant import brevitasToTrueQuant |
| 236 | + # FBRANCASI: Override the injectCustomForwards function in the module before DeepQuant.Export imports it |
| 237 | + import DeepQuant.Pipeline.Injection as injection_module |
| 238 | + |
| 239 | + # FBRANCASI: Store original function |
| 240 | + original_inject = injection_module.injectCustomForwards |
| 241 | + |
| 242 | + # FBRANCASI: Override with our custom function |
| 243 | + injection_module.injectCustomForwards = injectCustomForwards |
| 244 | + |
| 245 | + # FBRANCASI: Force reload of Export module to pick up the override |
| 246 | + import importlib |
| 247 | + |
| 248 | + import DeepQuant.Export |
| 249 | + |
| 250 | + importlib.reload(DeepQuant.Export) |
| 251 | + |
| 252 | + try: |
| 253 | + from DeepQuant.Export import brevitasToTrueQuant |
185 | 254 |
|
186 | | - brevitasToTrueQuant(quantizedModel, sampleInput, debug=True) |
| 255 | + brevitasToTrueQuant(quantizedModel, sampleInput, debug=True) |
| 256 | + finally: |
| 257 | + # FBRANCASI: Restore original function and reload Export module again |
| 258 | + injection_module.injectCustomForwards = original_inject |
| 259 | + importlib.reload(DeepQuant.Export) |
0 commit comments