
Commit 648155a

Add tests for channel-wise weight quantization
1 parent 9402356 commit 648155a

File tree

2 files changed: +167 -0 lines changed


Tests/TestConvChannelWise.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
# Copyright 2025 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Victor Jung <jungvi@iis.ee.ethz.ch>
# Federico Brancasi <fbrancasi@ethz.ch>

import pytest
import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.quant.scaled_int import (
    Int8ActPerTensorFloat,
    Int32Bias,
    Int8WeightPerChannelFloat
)
from DeepQuant.ExportBrevitas import exportBrevitas


class QuantConvNet(nn.Module):

    convAndLinQuantParams = {
        "bias": True,
        "weight_bit_width": 4,
        "bias_quant": Int32Bias,
        "input_quant": Int8ActPerTensorFloat,
        "weight_quant": Int8WeightPerChannelFloat,
        "output_quant": Int8ActPerTensorFloat,
        "return_quant_tensor": True,
    }

    def __init__(self, in_channels: int = 1) -> None:
        super().__init__()
        self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True)

        self.conv1 = qnn.QuantConv2d(
            in_channels=in_channels,
            out_channels=16,
            kernel_size=3,
            padding=1,
            **QuantConvNet.convAndLinQuantParams
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.inputQuant(x)
        x = self.conv1(x)

        return x


@pytest.mark.SingleLayerTests
def deepQuantTestConv() -> None:

    torch.manual_seed(42)

    model = QuantConvNet().eval()
    sampleInput = torch.randn(1, 1, 28, 28)
    exportBrevitas(model, sampleInput, debug=True)
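
Note (not part of the diff): what makes this test channel-wise is the Int8WeightPerChannelFloat quantizer, which assigns one weight scale per output channel instead of a single per-tensor scale. A minimal sketch of that property, assuming Brevitas exposes the quant_weight() accessor on QuantConv2d and using the QuantConvNet class defined above:

# Hypothetical check, not in the commit: with per-channel weight quantization,
# conv1 (16 output channels) should carry 16 weight scales.
model = QuantConvNet().eval()
scale = model.conv1.quant_weight().scale
assert scale.numel() == 16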

Tests/TestSimpleCNNChannelWise.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# Copyright 2025 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Federico Brancasi <fbrancasi@ethz.ch>


import pytest
import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.quant.scaled_int import (
    Int8ActPerTensorFloat,
    Int32Bias,
    Int8WeightPerChannelFloat,
)
from DeepQuant.ExportBrevitas import exportBrevitas


class SimpleQuantCNN(nn.Module):
    """
    A simple quantized CNN that includes:
    - Input quantization
    - Two QuantConv2d layers with quantized ReLU
    - MaxPool2d
    - A final QuantLinear layer
    """

    convAndLinQuantParams = {
        "bias": True,
        "weight_bit_width": 4,
        "bias_quant": Int32Bias,
        "input_quant": Int8ActPerTensorFloat,
        "weight_quant": Int8WeightPerChannelFloat,
        "output_quant": Int8ActPerTensorFloat,
        "return_quant_tensor": True,
    }

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        """
        Args:
            in_channels: Number of input channels (e.g., 1 for grayscale).
            num_classes: Number of output classes for the final linear layer.
        """
        super().__init__()
        self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True)

        self.conv1 = qnn.QuantConv2d(
            in_channels=in_channels,
            out_channels=16,
            kernel_size=3,
            padding=1,
            **SimpleQuantCNN.convAndLinQuantParams
        )
        self.relu1 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = qnn.QuantConv2d(
            in_channels=16,
            out_channels=32,
            kernel_size=3,
            padding=1,
            **SimpleQuantCNN.convAndLinQuantParams
        )
        self.relu2 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.flatten = nn.Flatten()
        self.fc = qnn.QuantLinear(
            in_features=32 * 7 * 7,  # If the input is 28x28, the shape after pooling is 7x7
            out_features=num_classes,
            **SimpleQuantCNN.convAndLinQuantParams
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the SimpleQuantCNN.

        Args:
            x: Input tensor of shape [batch_size, in_channels, height, width].

        Returns:
            A quantized output tensor of shape (batch_size, num_classes).
        """
        x = self.inputQuant(x)

        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        x = self.flatten(x)
        x = self.fc(x)
        return x


@pytest.mark.ModelTests
def deepQuantTestSimpleCNN() -> None:

    torch.manual_seed(42)

    model = SimpleQuantCNN().eval()
    sampleInput = torch.randn(1, 1, 28, 28)

    exportBrevitas(model, sampleInput, debug=True)
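
For context (also not part of the diff), the same per-channel behaviour should hold for every weight-quantized layer in SimpleQuantCNN: one scale per output channel for the convolutions and one per output feature for the linear layer. A rough sketch under the same assumption about Brevitas' quant_weight() accessor:

# Hypothetical checks, not in the commit: expected number of per-channel weight scales.
model = SimpleQuantCNN().eval()
assert model.conv1.quant_weight().scale.numel() == 16  # conv1 out_channels
assert model.conv2.quant_weight().scale.numel() == 32  # conv2 out_channels
assert model.fc.quant_weight().scale.numel() == 10     # fc out_features (num_classes)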
