@@ -22,6 +22,7 @@ def brevitasToTrueQuant(
2222 exampleInput : torch .Tensor ,
2323 exportPath : Optional [Union [str , Path ]] = Path .cwd () / "Tests" / "ONNX" ,
2424 debug : bool = False ,
25+ checkEquivalence : bool = False ,
2526) -> nn .Module :
2627 """
2728 Export a Brevitas model to an FX GraphModule with unrolled quantization operations.
@@ -35,16 +36,18 @@ def brevitasToTrueQuant(
3536
3637 # Pipeline Step 2: Inject custom forward implementations
3738 transformedModel , transformedOutput = injectCustomForwards (
38- tracedModel , exampleInput , originalOutput , debug
39+ tracedModel , exampleInput , originalOutput , debug , checkEquivalence
3940 )
4041
4142 # Pipeline Step 3: Split quantization nodes
4243 splitModel , splitOutput = splitQuantNodes (
43- transformedModel , exampleInput , transformedOutput , debug
44+ transformedModel , exampleInput , transformedOutput , debug , checkEquivalence
4445 )
4546
4647 # Pipeline Step 4: Unify dequant nodes
47- unifiedModel , _ = mergeDequants (splitModel , exampleInput , splitOutput , debug )
48+ unifiedModel , _ = mergeDequants (
49+ splitModel , exampleInput , splitOutput , debug , checkEquivalence
50+ )
4851
4952 # Pipeline Step 5: Export to ONNX
5053 onnxFile , _ = exportToOnnx (unifiedModel , exampleInput , exportPath , debug )
0 commit comments