Skip to content

Commit 1a28452

Browse files
Add initial project files
0 parents  commit 1a28452

File tree

96 files changed

+3453
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+3453
-0
lines changed

Makefile

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 ETH Zurich and University of Bologna.
2+
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Federico Brancasi <[email protected]>
6+
7+
.PHONY: test test-nn test-mha test-cnn
8+
9+
# Test Directory
10+
TEST_DIR = src/DeepQuant/tests
11+
12+
# Pytest flags
13+
PYTEST_FLAGS = -v -s
14+
15+
# Target for running all tests
16+
test:
17+
python -m pytest $(PYTEST_FLAGS) $(TEST_DIR)
18+
19+
# Target for running simple neural network test
20+
test-nn:
21+
python -m pytest $(PYTEST_FLAGS) $(TEST_DIR)/test_simple_nn.py
22+
23+
# Target for running multi-head attention test
24+
test-mha:
25+
python -m pytest $(PYTEST_FLAGS) $(TEST_DIR)/test_simple_mha.py
26+
27+
# Target for running convolutional neural network test
28+
test-cnn:
29+
python -m pytest $(PYTEST_FLAGS) $(TEST_DIR)/test_simple_cnn.py
30+
31+
# Target for running resnet test
32+
test-resnet:
33+
python -m pytest $(PYTEST_FLAGS) $(TEST_DIR)/test_resnet18.py
34+
35+
# Target for running mnist test
36+
test-mnist:
37+
python -m pytest $(PYTEST_FLAGS) $(TEST_DIR)/test_mnist.py
38+
39+
# Target for running a specific test (usage: make test-single TEST=test_simple_nn.py)
40+
test-single:
41+
ifdef TEST
42+
python -m pytest $(PYTEST_FLAGS) $(TEST_DIR)/$(TEST)
43+
else
44+
@echo "Please specify a test file with TEST=filename.py"
45+
endif
46+
47+
# Show help
48+
help:
49+
@echo "Available targets:"
50+
@echo " make test - Run all tests"
51+
@echo " make test-nn - Run simple neural network tests"
52+
@echo " make test-mha - Run multi-head attention tests"
53+
@echo " make test-cnn - Run convolutional neural network tests"
54+
@echo " make test-single TEST=filename.py - Run a specific test file"

README.md

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# DeepQuant
2+
3+
A Python library for exporting Brevitas quantized neural networks.
4+
5+
## Installation
6+
7+
### Requirements
8+
9+
- Python 3.11 or higher
10+
- PyTorch 2.1.2 or higher
11+
- Brevitas 0.11.0 or higher
12+
13+
### Setup Environment
14+
15+
First, create and activate a new conda environment:
16+
17+
```bash
18+
mamba create -n brevitas_env python=3.11
19+
mamba activate brevitas_env
20+
```
21+
22+
### Install Dependencies
23+
24+
Install PyTorch and its related packages:
25+
26+
```bash
27+
mamba install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 -c pytorch
28+
```
29+
30+
### Install the Package
31+
32+
Clone the repository and install in development mode:
33+
34+
```bash
35+
cd DeepQuant
36+
pip install -e .
37+
```
38+
39+
## Running Tests
40+
41+
### Using Make (Recommended)
42+
43+
The project includes a Makefile with several testing commands:
44+
45+
```bash
46+
# Run all tests with verbose output
47+
make test
48+
49+
# Run only neural network test
50+
make test-nn
51+
52+
# Run only multi-head attention test
53+
make test-mha
54+
55+
# Run only CNN test
56+
make test-cnn
57+
58+
# Run only Resnet18 test
59+
make test-resnet
60+
61+
# Run a specific test file
62+
make test-single TEST=test_simple_nn.py
63+
64+
# Show all available make commands
65+
make help
66+
```
67+
68+
### Using pytest directly
69+
70+
You can also run tests using pytest commands:
71+
72+
```bash
73+
# Run all tests
74+
python -m pytest src/DeepQuant/tests -v -s
75+
76+
# Run a specific test file
77+
python -m pytest src/DeepQuant/tests/test_simple_nn.py -v -s
78+
```
79+
80+
## Project Structure
81+
82+
```
83+
DeepQuant/
84+
├── Makefile
85+
├── pyproject.toml
86+
├── conftest.py
87+
└── src/
88+
└── DeepQuant/
89+
├── custom_forwards/
90+
│ ├── activations.py
91+
│ ├── linear.py
92+
│ └── multiheadattention.py
93+
├── injects/
94+
│ ├── base.py
95+
│ ├── executor.py
96+
│ └── transformations.py
97+
├── tests/
98+
│ ├── test_simple_mha.py
99+
│ ├── test_simple_nn.py
100+
│ └── test_simple_cnn.py
101+
├── custom_tracer.py
102+
└── export_brevitas.py
103+
```
104+
105+
### Key Components
106+
107+
- **Makefile**: Provides automation commands for testing
108+
- **pyproject.toml**: Defines project metadata and dependencies for editable installation
109+
- **conftest.py**: Pytest configuration file that handles warning filters
110+
111+
The source code is organized into several key modules:
112+
113+
- **custom_forwards/**: Contains the unrolled forward implementations for:
114+
115+
- Linear layers (QuantLinear, QuantConv2d)
116+
- Activation functions (QuantReLU, QuantSigmoid, etc.)
117+
- Multi-head attention (QuantMultiheadAttention)
118+
119+
- **injects/**: Contains the transformation infrastructure:
120+
121+
- Base transformation class and executor
122+
- Module-specific transformations
123+
- Validation and verification logic
124+
125+
- **tests/**: Example tests demonstrating the exporter usage:
126+
127+
- Simple neural network (linear + activations)
128+
- Multi-head attention model
129+
- Convolutional neural network
130+
- Resnet18
131+
132+
- **custom_tracer.py**: Implements a specialized `CustomBrevitasTracer` for FX tracing
133+
134+
- Handles Brevitas-specific module traversal
135+
- Ensures proper graph capture of quantization operations
136+
137+
- **export_brevitas.py**: Main API for end-to-end model export:
138+
- Orchestrates the transformation passes
139+
- Performs the final FX tracing
140+
- Validates model outputs through the process
141+
142+
## Usage
143+
144+
### Main Function: exportBrevitas
145+
146+
The main function of this library is `exportBrevitas`, which exports a Brevitas-based model to an FX GraphModule with unrolled quantization steps.
147+
148+
```python
149+
from DeepQuant.export_brevitas import exportBrevitas
150+
151+
# Initialize your Brevitas model
152+
model = YourBrevitasModel().eval()
153+
154+
# Create an input with the correct shape
155+
input = torch.randn(1, input_channels, height, width)
156+
157+
# Export the model (with debug information)
158+
fx_model = exportBrevitas(model, input, debug=True)
159+
```
160+
161+
Arguments:
162+
163+
- `model`: The Brevitas-based model to export
164+
- `example_input`: A representative input tensor for shape tracing
165+
- `debug`: If True, prints transformation progress (default: False)
166+
167+
When `debug=True`, you'll see the output showing the progress, for example:
168+
169+
```
170+
✓ MHA transformation successful - outputs match
171+
✓ Linear transformation successful - outputs match
172+
✓ Activation transformation successful - outputs match
173+
All transformations completed successfully!
174+
```
175+
176+
### Example Usage
177+
178+
A simple example script can be found in `example_usage.py` in the root directory of the project.
179+
180+
```python
181+
import torch
182+
import torch.nn as nn
183+
import brevitas.nn as qnn
184+
from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int32Bias
185+
from DeepQuant.export_brevitas import exportBrevitas
186+
187+
# Define a simple quantized model
188+
class SimpleQuantModel(nn.Module):
189+
def __init__(self):
190+
super().__init__()
191+
self.input_quant = qnn.QuantIdentity(return_quant_tensor=True)
192+
self.conv = qnn.QuantConv2d(
193+
in_channels=3,
194+
out_channels=16,
195+
kernel_size=3,
196+
bias=True,
197+
weight_bit_width=4,
198+
bias_quant=Int32Bias,
199+
output_quant=Int8ActPerTensorFloat,
200+
)
201+
202+
def forward(self, x):
203+
x = self.input_quant(x)
204+
x = self.conv(x)
205+
return x
206+
207+
# Export the model
208+
model = SimpleQuantModel().eval()
209+
dummy_input = torch.randn(1, 3, 32, 32)
210+
fx_model = exportBrevitas(model, dummy_input, debug=True)
211+
```
1.58 KB
Binary file not shown.
1.58 KB
Binary file not shown.

conftest.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2025 ETH Zurich and University of Bologna.
2+
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Federico Brancasi <[email protected]>
6+
7+
"""
8+
Pytest configuration file that suppresses specific warnings, including those
9+
related to torch.tensor constant registration in FX tracing.
10+
"""
11+
12+
import warnings
13+
import pytest
14+
15+
# Attempt to import TracerWarning from torch.fx.proxy;
16+
# if unavailable, skip filtering by category.
17+
try:
18+
from torch.fx.proxy import TracerWarning
19+
20+
warnings.filterwarnings("ignore", category=TracerWarning)
21+
except ImportError:
22+
pass
23+
24+
warnings.filterwarnings("ignore", category=DeprecationWarning)
25+
warnings.filterwarnings("ignore", category=UserWarning, message="Named tensors.*")
26+
warnings.filterwarnings(
27+
"ignore", category=UserWarning, message=".*__torch_function__.*"
28+
)
29+
warnings.filterwarnings(
30+
"ignore", category=UserWarning, message="Was not able to add assertion.*"
31+
)
32+
warnings.filterwarnings(
33+
"ignore", category=UserWarning, message="'has_cuda' is deprecated.*"
34+
)
35+
warnings.filterwarnings(
36+
"ignore", category=UserWarning, message="'has_cudnn' is deprecated.*"
37+
)
38+
warnings.filterwarnings(
39+
"ignore", category=UserWarning, message="'has_mps' is deprecated.*"
40+
)
41+
warnings.filterwarnings(
42+
"ignore", category=UserWarning, message="'has_mkldnn' is deprecated.*"
43+
)

custom_forwards.egg-info/PKG-INFO

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
Metadata-Version: 2.2
2+
Name: custom_forwards
3+
Version: 0.1.0
4+
Summary: Custom PyTorch forwards library
5+
Requires-Python: >=3.11
6+
Requires-Dist: torch>=2.1.2
7+
Requires-Dist: torchvision>=0.16.2
8+
Requires-Dist: torchaudio>=2.1.2
9+
Requires-Dist: brevitas
10+
Requires-Dist: torchmetrics
11+
Provides-Extra: dev
12+
Requires-Dist: black; extra == "dev"
13+
Requires-Dist: isort; extra == "dev"
14+
Requires-Dist: pytest; extra == "dev"
15+
Requires-Dist: netron; extra == "dev"
16+
Requires-Dist: tabulate; extra == "dev"
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
pyproject.toml
2+
setup.py
3+
custom_forwards.egg-info/PKG-INFO
4+
custom_forwards.egg-info/SOURCES.txt
5+
custom_forwards.egg-info/dependency_links.txt
6+
custom_forwards.egg-info/requires.txt
7+
custom_forwards.egg-info/top_level.txt
8+
src/__init__.py
9+
src/custom_tracer.py
10+
src/export_brevitas.py
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
torch>=2.1.2
2+
torchvision>=0.16.2
3+
torchaudio>=2.1.2
4+
brevitas
5+
torchmetrics
6+
7+
[dev]
8+
black
9+
isort
10+
pytest
11+
netron
12+
tabulate
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
src

0 commit comments

Comments
 (0)