
Commit 98c828c

Refactor codebase & fix ResNet-18 test (#1)
* Initial commit fbrancasi/dev
* Working ResNet-18
* Codebase refactor
* Update ResNet-18 test
* Fix CI
* Minor fixes
1 parent 0a0ea5b commit 98c828c

40 files changed: +1667 / -2018 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ dist/
 *.gz
 *-ubyte
 *.pth
+*.pt
 *.onnx
 *.npz
 onnx/*

DeepQuant/CustomForwards/Activations.py

Lines changed: 5 additions & 40 deletions
@@ -4,63 +4,28 @@
 #
 # Federico Brancasi <[email protected]>
 
-
 import torch.nn as nn
-from torch import Tensor
 from brevitas.nn.quant_layer import QuantNonLinearActLayer
+from torch import Tensor
 
 
-class InnerForwardImplWrapperActivation(nn.Module):
-    """
-    A small wrapper around the activation function of a Brevitas QuantActivation layer.
-
-    This wrapper exposes the original activation function as a standalone submodule
-    so that FX tracing can display it as a separate node.
-    """
+class WrapperActivation(nn.Module):
+    """Expose inner activation so FX sees it as a leaf."""
 
     def __init__(self, actImpl: nn.Module) -> None:
-        """
-        Args:
-            act_impl: The original activation function module (e.g. an instance of nn.ReLU).
-        """
         super().__init__()
         self.actImpl = actImpl
 
     def forward(self, quantInput: Tensor) -> Tensor:
-        """
-        Applies the wrapped activation function.
-
-        Args:
-            quant_input: Input tensor after input quantization.
-
-        Returns:
-            Output tensor after applying the activation.
-        """
         return self.actImpl(quantInput)
 
 
-def quantActivationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor:
-    """
-    Unrolled forward pass for a Brevitas QuantActivation layer.
-
-    Steps:
-      1) Apply self.input_quant to the input.
-      2) Apply the activation function via the wrapped activation implementation.
-      3) Apply self.act_quant to the activation output.
-
-    Args:
-        self: The QuantNonLinearActLayer instance.
-        inp: The input tensor.
-
-    Returns:
-        Output tensor after applying activation and output quantization.
-    """
+def activationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor:
+    """Unroll input→act→output quant steps."""
     quantInput = self.input_quant(inp) if self.input_quant is not None else inp
-    # Use the wrapped activation if available; otherwise pass through.
     if hasattr(self, "wrappedActImpl"):
         output = self.wrappedActImpl(quantInput)
     else:
         output = quantInput
-    import IPython; IPython.embed()
     quantOutput = self.act_quant(output) if self.act_quant is not None else output
     return quantOutput
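The commit does not show where this custom forward is applied, so the following is a minimal usage sketch rather than the repository's own wiring: it exposes a stand-in nn.ReLU through WrapperActivation under the wrappedActImpl attribute that activationForward checks for, and binds the unrolled forward onto a Brevitas QuantReLU. The QuantReLU instance, the nn.ReLU stand-in, and the types.MethodType binding are assumptions for illustration.

import types

import torch
import torch.nn as nn
from brevitas.nn import QuantReLU

from DeepQuant.CustomForwards.Activations import WrapperActivation, activationForward

# Hypothetical wiring: expose an inner activation (here a plain nn.ReLU) as a
# submodule so FX traces it as its own node, then swap in the unrolled forward.
layer = QuantReLU(return_quant_tensor=False)
layer.wrappedActImpl = WrapperActivation(nn.ReLU())
layer.forward = types.MethodType(activationForward, layer)

out = layer(torch.randn(2, 8))  # input_quant -> wrapped ReLU -> act_quant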

DeepQuant/CustomForwards/Linear.py

Lines changed: 0 additions & 75 deletions
This file was deleted.

DeepQuant/CustomForwards/MultiHeadAttention.py

Lines changed: 4 additions & 28 deletions
@@ -4,43 +4,22 @@
 #
 # Federico Brancasi <[email protected]>
 
-
 import math
+
 import torch
 import torch.nn.functional as F
-from torch import Tensor
 from brevitas.nn.quant_mha import QuantMultiheadAttention
+from torch import Tensor
 
 
-def unrolledQuantMhaForward(
+def mhaForward(
     self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor
 ) -> Tensor:
-    """
-    Export-friendly forward that explicitly unrolls the multi-head logic.
-
-    Steps:
-      1) Q, K, V projections
-      2) Reshapes & permutes for multi-head
-      3) Scales queries
-      4) Applies softmax and intermediate quantizations
-      5) Out projection
-
-    Args:
-        self: The QuantMultiheadAttention instance.
-        query: The query tensor of shape [sequence_len, batch_size, embed_dim].
-        key: The key tensor, same shape as query.
-        value: The value tensor, same shape as query.
-
-    Returns:
-        A torch.Tensor of shape [sequence_len, batch_size, embed_dim]
-        after the unrolled MHA steps.
-    """
-    # 1) Q, K, V projections
+    """Explicit, export-friendly MHA forward."""
     qOut = self.q_proj(query)
     kOut = self.k_proj(key)
     vOut = self.v_proj(value)
 
-    # 2) Multi-head reshape
     seqLen, batchSize, embedDim = qOut.shape
     headDim = embedDim // self.num_heads
 
@@ -60,11 +39,9 @@ def unrolledQuantMhaForward(
         .reshape(batchSize * self.num_heads, seqLen, headDim)
     )
 
-    # 3) Scale queries, then quantize
     qScaled = qOut / math.sqrt(headDim)
     qScaled = self.q_scaled_quant(qScaled)
 
-    # 4) Transpose + quantize K, compute attention weights
     k_t = kOut.transpose(-2, -1)
     k_t = self.k_transposed_quant(k_t)
 
@@ -73,7 +50,6 @@
     attnWeights = F.softmax(attnWeights, dim=-1)
     attnWeights = self.attn_output_weights_quant(attnWeights)
 
-    # 5) Quantize V, multiply, reshape back, and final out projection
     vOut = self.v_quant(vOut)
     attnOutput = torch.bmm(attnWeights, vOut)
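The call site for this forward is likewise not part of the diff; the sketch below shows one way mhaForward could be bound to a Brevitas QuantMultiheadAttention instance before FX tracing or export. The layer sizes, the packed_in_proj=False argument (assumed so the layer keeps the separate q_proj/k_proj/v_proj projections this forward calls), and the types.MethodType binding are illustrative assumptions, not code from the commit.

import types

import torch
from brevitas.nn import QuantMultiheadAttention

from DeepQuant.CustomForwards.MultiHeadAttention import mhaForward

# Hypothetical setup: a small quantized MHA layer with separate Q/K/V projections,
# since mhaForward uses self.q_proj / self.k_proj / self.v_proj directly.
mha = QuantMultiheadAttention(embed_dim=16, num_heads=4, packed_in_proj=False)
mha.forward = types.MethodType(mhaForward, mha)

seq = torch.randn(10, 2, 16)  # [sequence_len, batch_size, embed_dim]
out = mha(seq, seq, seq)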

DeepQuant/CustomForwards/WBIOL.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# Copyright 2025 ETH Zurich and University of Bologna.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Federico Brancasi <[email protected]>
+
+import torch.nn as nn
+from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer
+from torch import Tensor
+
+
+class WrapperWBIOL(nn.Module):
+    """Expose `inner_forward_impl` as a standalone submodule."""
+
+    def __init__(self, innerForwardImpl: nn.Module) -> None:
+        super().__init__()
+        self.innerForwardImpl = innerForwardImpl
+
+    def forward(
+        self, quantInput: Tensor, quantWeight: Tensor, quantBias: Tensor
+    ) -> Tensor:
+        return self.innerForwardImpl(quantInput, quantWeight, quantBias)
+
+
+def WBIOLForward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor:
+    """Quant-in → quant-weight/bias → matmul → quant-out."""
+    quantInput = self.input_quant(inp)
+    quantWeight = self.weight_quant(self.weight)
+
+    quantBias = None
+    if self.bias is not None:
+        quantBias = self.bias_quant(self.bias, quantInput, quantWeight)
+
+    output = self.wrappedInnerForwardImpl(quantInput, quantWeight, quantBias)
+    quantOutput = self.output_quant(output)
+    return quantOutput
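As with the other custom forwards, the commit does not show how this one is attached, so the sketch below is only a plausible wiring: it exposes a QuantLinear's inner_forward_impl through WrapperWBIOL under the wrappedInnerForwardImpl attribute that WBIOLForward expects, then binds the unrolled forward to the layer. The layer dimensions and the types.MethodType binding are illustrative, and the bias_quant call signature may need adjusting to the installed Brevitas version.

import types

import torch
from brevitas.nn import QuantLinear

from DeepQuant.CustomForwards.WBIOL import WrapperWBIOL, WBIOLForward

# Hypothetical wiring: expose the layer's inner_forward_impl as a submodule so FX
# traces quant-in, quant-weight/bias, the inner matmul, and quant-out separately.
fc = QuantLinear(16, 8, bias=True)
fc.wrappedInnerForwardImpl = WrapperWBIOL(fc.inner_forward_impl)
fc.forward = types.MethodType(WBIOLForward, fc)

out = fc(torch.randn(4, 16))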

DeepQuant/CustomTracer.py

Lines changed: 0 additions & 110 deletions
This file was deleted.
