Skip to content

Commit 2277543

Browse files
Add checkEquivalence Flag to brevitasToTrueQuant function
1 parent: 2cbc76c; commit: 2277543

File tree

8 files changed, with 35 additions and 26 deletions.

8 files changed

+35
-26
lines changed

DeepQuant/Export.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def brevitasToTrueQuant(
2222
exampleInput: torch.Tensor,
2323
exportPath: Optional[Union[str, Path]] = Path.cwd() / "Tests" / "ONNX",
2424
debug: bool = False,
25+
checkEquivalence: bool = False,
2526
) -> nn.Module:
2627
"""
2728
Export a Brevitas model to an FX GraphModule with unrolled quantization operations.
@@ -35,16 +36,18 @@ def brevitasToTrueQuant(
3536

3637
# Pipeline Step 2: Inject custom forward implementations
3738
transformedModel, transformedOutput = injectCustomForwards(
38-
tracedModel, exampleInput, originalOutput, debug
39+
tracedModel, exampleInput, originalOutput, debug, checkEquivalence
3940
)
4041

4142
# Pipeline Step 3: Split quantization nodes
4243
splitModel, splitOutput = splitQuantNodes(
43-
transformedModel, exampleInput, transformedOutput, debug
44+
transformedModel, exampleInput, transformedOutput, debug, checkEquivalence
4445
)
4546

4647
# Pipeline Step 4: Unify dequant nodes
47-
unifiedModel, _ = mergeDequants(splitModel, exampleInput, splitOutput, debug)
48+
unifiedModel, _ = mergeDequants(
49+
splitModel, exampleInput, splitOutput, debug, checkEquivalence
50+
)
4851

4952
# Pipeline Step 5: Export to ONNX
5053
onnxFile, _ = exportToOnnx(unifiedModel, exampleInput, exportPath, debug)

DeepQuant/Pipeline/DequantUnify.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def mergeDequants(
2020
exampleInput: torch.Tensor,
2121
referenceOutput: torch.Tensor,
2222
debug: bool = False,
23+
checkEquivalence: bool = False,
2324
) -> Tuple[nn.Module, torch.Tensor]:
2425
"""
2526
Unify dequantization nodes to enable integer-only computation.
@@ -78,12 +79,13 @@ def mergeDequants(
7879
output = unifiedModel(exampleInput)
7980

8081
# FBRANCASI: Check output equivalence with a warning instead of error
81-
if not torch.allclose(referenceOutput, output, atol=1e-5) and debug:
82-
print(
83-
cc.warning(
84-
"Modification of Dequant Nodes may have changed the output slightly"
82+
if checkEquivalence:
83+
if not torch.allclose(referenceOutput, output, atol=1e-5) and debug:
84+
print(
85+
cc.warning(
86+
"Modification of Dequant Nodes may have changed the output slightly"
87+
)
8588
)
86-
)
8789

8890
if debug:
8991
# FBRANCASI: Register hooks for the unified model and compare tensors

DeepQuant/Pipeline/Injection.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def injectCustomForwards(
2525
exampleInput: torch.Tensor,
2626
referenceOutput: torch.Tensor,
2727
debug: bool = False,
28+
checkEquivalence: bool = False,
2829
) -> Tuple[nn.Module, torch.Tensor]:
2930
"""Inject custom forward implementations into the model."""
3031
printer = GraphModulePrinter()
@@ -49,13 +50,14 @@ def injectCustomForwards(
4950
with torch.no_grad():
5051
output = fxModel(exampleInput)
5152

52-
if torch.allclose(referenceOutput, output, atol=1e-5):
53-
if debug:
54-
print(cc.success("Injection of New Modules: output is consistent"))
55-
else:
56-
raise RuntimeError(
57-
cc.error("Injection of New Modules changed the output significantly")
58-
)
53+
if checkEquivalence:
54+
if torch.allclose(referenceOutput, output, atol=1e-5):
55+
if debug:
56+
print(cc.success("Injection of New Modules: output is consistent"))
57+
else:
58+
raise RuntimeError(
59+
cc.error("Injection of New Modules changed the output significantly")
60+
)
5961

6062
if debug:
6163
print(cc.header("2. Network after Injection of New Modules"))

DeepQuant/Pipeline/QuantSplit.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def splitQuantNodes(
2323
exampleInput: torch.Tensor,
2424
referenceOutput: torch.Tensor,
2525
debug: bool = False,
26+
checkEquivalence: bool = False,
2627
) -> Tuple[nn.Module, torch.Tensor]:
2728
"""
2829
Split quantization nodes into separate Quant and Dequant nodes.
@@ -44,13 +45,14 @@ def splitQuantNodes(
4445
with torch.no_grad():
4546
output = splitModel(exampleInput)
4647

47-
if torch.allclose(referenceOutput, output, atol=1e-5):
48-
if debug:
49-
print(cc.success("Split of Quant Nodes: output is consistent"))
50-
else:
51-
raise RuntimeError(
52-
cc.error("Split of Quant Nodes changed the output significantly")
53-
)
48+
if checkEquivalence:
49+
if torch.allclose(referenceOutput, output, atol=1e-5):
50+
if debug:
51+
print(cc.success("Split of Quant Nodes: output is consistent"))
52+
else:
53+
raise RuntimeError(
54+
cc.error("Split of Quant Nodes changed the output significantly")
55+
)
5456

5557
if debug:
5658
print(cc.header("3. Network after Split of Quant Nodes"))

Tests/TestConv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ def deepQuantTestConv() -> None:
5353
torch.manual_seed(42)
5454
model = QuantConvNet().eval()
5555
sampleInput = torch.randn(1, 1, 28, 28)
56-
brevitasToTrueQuant(model, sampleInput, debug=True)
56+
brevitasToTrueQuant(model, sampleInput, debug=True, checkEquivalence=True)

Tests/TestLinear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ def deepQuantTestLinear() -> None:
4646
torch.manual_seed(42)
4747
model = QuantLinearNet().eval()
4848
sampleInput = torch.randn(1, 4, 16)
49-
brevitasToTrueQuant(model, sampleInput, debug=True)
49+
brevitasToTrueQuant(model, sampleInput, debug=True, checkEquivalence=True)

Tests/TestMHSA.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,4 @@ def deepQuantTestMHSA() -> None:
5656
torch.manual_seed(42)
5757
model = QuantMHSANet(embedDim=16, numHeads=4).eval()
5858
sampleInput = torch.randn(10, 2, 16)
59-
brevitasToTrueQuant(model, sampleInput)
59+
brevitasToTrueQuant(model, sampleInput, checkEquivalence=True)

Tests/TestSimpleCNN.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,4 @@ def deepQuantTestSimpleCNN() -> None:
8282
torch.manual_seed(42)
8383
model = SimpleQuantCNN().eval()
8484
sampleInput = torch.randn(1, 1, 28, 28)
85-
brevitasToTrueQuant(model, sampleInput, debug=True)
85+
brevitasToTrueQuant(model, sampleInput, debug=True, checkEquivalence=True)

0 commit comments

Comments
 (0)