Commit e1eadd0

Add unit layer tests, MBNetV3 test, and mark tests into two categories
1 parent e36a703 commit e1eadd0

11 files changed: +282 −146 lines


.github/workflows/CI.yml

Lines changed: 2 additions & 4 deletions
@@ -36,9 +36,7 @@ jobs:
           pip install -e .
       - name: Run Tests
         run: |
-          pytest Tests/TestSimpleCNN.py
-          pytest Tests/TestSimpleMHA.py
-          pytest Tests/TestSimpleNN.py
+          pytest -m SingleLayerTests
 
   model-tests:
     runs-on: ubuntu-latest
@@ -55,4 +53,4 @@ jobs:
           pip install -e .
       - name: Run Tests
         run: |
-          pytest Tests/TestMnist.py
+          pytest -m ModelTests
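Selecting tests with pytest -m works cleanly only if the two markers are registered; unregistered markers trigger PytestUnknownMarkWarning. A minimal sketch of one way to register them in a conftest.py, as an assumption (the repository may instead declare them in pytest.ini or pyproject.toml, which this commit does not show):

# conftest.py (hypothetical): register the custom markers used by the CI split
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "SingleLayerTests: fast unit tests that export a single quantized layer",
    )
    config.addinivalue_line(
        "markers",
        "ModelTests: slower end-to-end tests that export a full quantized model",
    )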

DeepQuant/ExportBrevitas.py

Lines changed: 12 additions & 21 deletions
@@ -95,12 +95,12 @@ def exportBrevitas(
         exampleInput
     )  # Compute original model output on example input for validation
 
-    export_onnx_qcdq(  # Export original model to ONNX format with QCDQ (Quant-Cast-DeQuant) nodes
-        model,  # Model to export
-        args=exampleInput,  # Example input for tracing
-        export_path=EXPORT_FOLDER / "1_model_qcdq_original.onnx",
-        opset_version=13,
-    )
+    # export_onnx_qcdq(  # Export original model to ONNX format with QCDQ (Quant-Cast-DeQuant) nodes
+    #     model,  # Model to export
+    #     args=exampleInput,  # Example input for tracing
+    #     export_path=EXPORT_FOLDER / "1_model_qcdq_original.onnx",
+    #     opset_version=13,
+    # )
 
     ###############################################################################
     # 2. Injection of New Modules
@@ -151,12 +151,12 @@ def exportBrevitas(
     print("\n=== 2. Network after the Injection of New Modules ===\n")
     printer.print_tabular(fxModel)
 
-    export_onnx_qcdq(  # Export transformed model to ONNX
-        fxModel,  # Transformed model
-        args=exampleInput,
-        export_path=EXPORT_FOLDER / "2_model_qcdq_transformed.onnx",
-        opset_version=13,
-    )
+    # export_onnx_qcdq(  # Export transformed model to ONNX
+    #     fxModel,  # Transformed model
+    #     args=exampleInput,
+    #     export_path=EXPORT_FOLDER / "2_model_qcdq_transformed.onnx",
+    #     opset_version=13,
+    # )
 
     ###############################################################################
     # 3. Extraction of Parameters & Split of Quant Nodes
@@ -274,15 +274,6 @@ def exportBrevitas(
             f"{RED} ✗ Modification of Dequant Nodes changed the output significantly{ENDC}"
         )
 
-    # try:
-    #     tracer = NodeTracer(debug=True)
-    #     tracer.trace(fx_model_unified, example_input)
-    #     if debug:
-    #         print(f"{BLUE} ✓ Tracing completed{ENDC}")
-    # except Exception as e:
-    #     print(f"{RED} ✗ Tracing failed: {str(e)}{ENDC}")
-    #     print("This doesn't affect the validity of the exported model")
-
     import numpy as np
     import onnxruntime as ort
     import onnx
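The numpy, onnxruntime, and onnx imports kept at the end of this hunk feed the numerical validation of the exported graph. A minimal sketch of that kind of check, assuming onnxruntime is installed; the helper name, path handling, and tolerance below are illustrative and not part of this commit:

import numpy as np
import onnxruntime as ort
import torch

def checkOnnxAgainstTorch(model, exampleInput, onnxPath, atol=1e-5):
    # Reference output from the quantized PyTorch model
    with torch.no_grad():
        torchOut = model(exampleInput)
    if hasattr(torchOut, "value"):  # unwrap a Brevitas QuantTensor if one is returned
        torchOut = torchOut.value

    # Output of the exported ONNX graph, computed with onnxruntime
    session = ort.InferenceSession(str(onnxPath))
    inputName = session.get_inputs()[0].name
    onnxOut = session.run(None, {inputName: exampleInput.numpy()})[0]

    # Element-wise comparison within the given tolerance
    return np.allclose(torchOut.numpy(), onnxOut, atol=atol)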

Tests/TestConv.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# Copyright 2025 ETH Zurich and University of Bologna.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Victor Jung <[email protected]>
+# Federico Brancasi <[email protected]>
+
+
+import pytest
+import torch
+import torch.nn as nn
+import brevitas.nn as qnn
+from brevitas.quant.scaled_int import (
+    Int8ActPerTensorFloat,
+    Int32Bias,
+    Int8WeightPerTensorFloat,
+)
+from DeepQuant.ExportBrevitas import exportBrevitas
+
+
+class QuantConvNet(nn.Module):
+
+    convAndLinQuantParams = {
+        "bias": True,
+        "weight_bit_width": 4,
+        "bias_quant": Int32Bias,
+        "input_quant": Int8ActPerTensorFloat,
+        "weight_quant": Int8WeightPerTensorFloat,
+        "output_quant": Int8ActPerTensorFloat,
+        "return_quant_tensor": True,
+    }
+
+    def __init__(self, in_channels: int = 1) -> None:
+        super().__init__()
+        self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True)
+
+        self.conv1 = qnn.QuantConv2d(
+            in_channels=in_channels,
+            out_channels=16,
+            kernel_size=3,
+            padding=1,
+            **QuantConvNet.convAndLinQuantParams
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        x = self.inputQuant(x)
+        x = self.conv1(x)
+
+        return x
+
+
+@pytest.mark.SingleLayerTests
+def deepQuantTestConv() -> None:
+
+    torch.manual_seed(42)
+
+    model = QuantConvNet().eval()
+    sampleInput = torch.randn(1, 1, 28, 28)
+    exportBrevitas(model, sampleInput, debug=True)

Tests/TestLinear.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# Copyright 2025 ETH Zurich and University of Bologna.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Federico Brancasi <[email protected]>
+
+
+import pytest
+
+### PyTorch Imports ###
+import torch
+import torch.nn as nn
+
+### Brevitas Import ###
+import brevitas.nn as qnn
+from brevitas.quant.scaled_int import (
+    Int8ActPerTensorFloat,
+    Int32Bias,
+    Int8WeightPerTensorFloat,
+)
+from DeepQuant.ExportBrevitas import exportBrevitas
+
+
+class QuantLinearNet(nn.Module):
+
+    def __init__(self, in_features: int = 16, hidden_features: int = 32) -> None:
+        super().__init__()
+
+        self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True)
+
+        self.linear1 = qnn.QuantLinear(
+            in_features=in_features,
+            out_features=hidden_features,
+            bias=True,
+            weight_bit_width=4,
+            bias_quant=Int32Bias,
+            output_quant=Int8ActPerTensorFloat,
+            input_quant=Int8ActPerTensorFloat,
+            weight_quant=Int8WeightPerTensorFloat,
+            return_quant_tensor=True,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        x = self.inputQuant(x)
+        x = self.linear1(x)
+
+        return x
+
+
+@pytest.mark.SingleLayerTests
+def deepQuantTestLinear() -> None:
+
+    torch.manual_seed(42)
+
+    model = QuantLinearNet().eval()
+    sampleInput = torch.randn(1, 4, 16)
+
+    exportBrevitas(model, sampleInput, debug=True)
Lines changed: 5 additions & 3 deletions
@@ -5,6 +5,7 @@
 # Federico Brancasi <[email protected]>
 
 
+import pytest
 import torch
 import torch.nn as nn
 import brevitas.nn as qnn
@@ -19,7 +20,7 @@
 )
 
 
-class SimpleQuantMHA(nn.Module):
+class QuantMHSANet(nn.Module):
 
     def __init__(self, embed_dim: int, num_heads: int) -> None:
         """
@@ -65,11 +66,12 @@ def forward(self, x: Tensor) -> Tensor:
         return out
 
 
-def deepQuantTestSimpleQuantMHA() -> None:
+@pytest.mark.SingleLayerTests
+def deepQuantTestMHSA() -> None:
 
     torch.manual_seed(42)
 
-    model = SimpleQuantMHA(embed_dim=16, num_heads=4).eval()
+    model = QuantMHSANet(embed_dim=16, num_heads=4).eval()
     sampleInput = torch.randn(10, 2, 16)
 
     exportBrevitas(model, sampleInput, debug=True)

Tests/TestMobileNetV3Small.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+# Copyright 2025 ETH Zurich and University of Bologna.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Victor Jung <[email protected]>
+
+import pytest
+import torch
+import torch.nn as nn
+import torchvision.models as models
+from brevitas.graph.quantize import preprocess_for_quantize
+from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool
+import brevitas.nn as qnn
+from brevitas.quant import (
+    Int8ActPerTensorFloat,
+    Int8WeightPerTensorFloat,
+    Int32Bias,
+    Uint8ActPerTensorFloat,
+)
+from brevitas.graph.quantize import quantize
+
+from DeepQuant.ExportBrevitas import exportBrevitas
+
+
+def prepareMBNetV3Model() -> nn.Module:
+    """
+    Prepare a quantized MobileNetV3Small model for testing.
+    Steps:
+    1) Load the torchvision MobileNetV3Small.
+    2) Convert it to eval mode.
+    3) Preprocess and adapt average pooling.
+    4) Quantize it using Brevitas.
+
+    Returns:
+        A quantized MobileNetV3Small model ready for export tests.
+    """
+    baseModel = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1)
+    baseModel = baseModel.eval()
+
+    computeLayerMap = {
+        nn.Conv2d: (
+            qnn.QuantConv2d,
+            {
+                "input_quant": Int8ActPerTensorFloat,
+                "weight_quant": Int8WeightPerTensorFloat,
+                "output_quant": Int8ActPerTensorFloat,
+                "bias_quant": Int32Bias,
+                "bias": True,
+                "return_quant_tensor": True,
+                "output_bit_width": 8,
+                "weight_bit_width": 4,
+            },
+        ),
+        nn.Linear: (
+            qnn.QuantLinear,
+            {
+                "input_quant": Int8ActPerTensorFloat,
+                "weight_quant": Int8WeightPerTensorFloat,
+                "output_quant": Int8ActPerTensorFloat,
+                "bias_quant": Int32Bias,
+                "bias": True,
+                "return_quant_tensor": True,
+                "output_bit_width": 8,
+                "weight_bit_width": 4,
+            },
+        ),
+    }
+
+    quantActMap = {
+        nn.ReLU: (
+            qnn.QuantReLU,
+            {
+                "act_quant": Uint8ActPerTensorFloat,
+                "return_quant_tensor": True,
+                "bit_width": 8,
+            },
+        ),
+    }
+
+    quantIdentityMap = {
+        "signed": (
+            qnn.QuantIdentity,
+            {
+                "act_quant": Int8ActPerTensorFloat,
+                "return_quant_tensor": True,
+                "bit_width": 8,
+            },
+        ),
+        "unsigned": (
+            qnn.QuantIdentity,
+            {
+                "act_quant": Uint8ActPerTensorFloat,
+                "return_quant_tensor": True,
+                "bit_width": 8,
+            },
+        ),
+    }
+
+    baseModel = preprocess_for_quantize(
+        baseModel, equalize_iters=20, equalize_scale_computation="range"
+    )
+    baseModel = AdaptiveAvgPoolToAvgPool().apply(
+        baseModel, torch.ones(1, 3, 224, 224)
+    )
+
+    quantizedModel = quantize(
+        graph_model=baseModel,
+        compute_layer_map=computeLayerMap,
+        quant_act_map=quantActMap,
+        quant_identity_map=quantIdentityMap,
+    )
+
+    return quantizedModel
+
+
+@pytest.mark.ModelTests
+def deepQuantTestMobileNetV3Small() -> None:
+
+    torch.manual_seed(42)
+
+    quantizedModel = prepareMBNetV3Model()
+    sampleInput = torch.randn(1, 3, 224, 224)
+
+    exportBrevitas(quantizedModel, sampleInput, debug=True)
