
Commit e36a703

First round of refactoring
1 parent dd5b2ff · commit e36a703

24 files changed: +356 −577 lines

.github/workflows/CI.yml

Lines changed: 4 additions & 4 deletions
@@ -36,9 +36,9 @@ jobs:
           pip install -e .
       - name: Run Tests
         run: |
-          pytest Tests/test_simple_cnn.py
-          pytest Tests/test_simple_mha.py
-          pytest Tests/test_simple_nn.py
+          pytest Tests/TestSimpleCNN.py
+          pytest Tests/TestSimpleMHA.py
+          pytest Tests/TestSimpleNN.py

   model-tests:
     runs-on: ubuntu-latest
@@ -55,4 +55,4 @@ jobs:
           pip install -e .
       - name: Run Tests
         run: |
-          pytest Tests/test_mnist.py
+          pytest Tests/TestMnist.py
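
The renamed test modules can also be run outside CI. A quick local sketch (not part of the commit), using pytest's Python entry point instead of the shell step:

    import pytest

    # Mirrors the updated unit-test step of the CI workflow.
    pytest.main(["Tests/TestSimpleCNN.py", "Tests/TestSimpleMHA.py", "Tests/TestSimpleNN.py"])
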
Lines changed: 12 additions & 15 deletions
@@ -4,11 +4,7 @@
 #
 # Federico Brancasi <[email protected]>

-"""
-Custom forward implementations for Brevitas QuantActivation layers.
-"""

-import torch
 import torch.nn as nn
 from torch import Tensor
 from brevitas.nn.quant_layer import QuantNonLinearActLayer
@@ -22,15 +18,15 @@ class InnerForwardImplWrapperActivation(nn.Module):
     so that FX tracing can display it as a separate node.
     """

-    def __init__(self, act_impl: nn.Module) -> None:
+    def __init__(self, actImpl: nn.Module) -> None:
         """
         Args:
             act_impl: The original activation function module (e.g. an instance of nn.ReLU).
         """
         super().__init__()
-        self.act_impl = act_impl
+        self.actImpl = actImpl

-    def forward(self, quant_input: Tensor) -> Tensor:
+    def forward(self, quantInput: Tensor) -> Tensor:
         """
         Applies the wrapped activation function.

@@ -40,10 +36,10 @@ def forward(self, quant_input: Tensor) -> Tensor:
         Returns:
             Output tensor after applying the activation.
         """
-        return self.act_impl(quant_input)
+        return self.actImpl(quantInput)


-def quant_activation_forward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor:
+def quantActivationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor:
     """
     Unrolled forward pass for a Brevitas QuantActivation layer.

@@ -59,11 +55,12 @@ def quant_activation_forward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor:
     Returns:
         Output tensor after applying activation and output quantization.
     """
-    quant_input = self.input_quant(inp) if self.input_quant is not None else inp
+    quantInput = self.input_quant(inp) if self.input_quant is not None else inp
     # Use the wrapped activation if available; otherwise pass through.
-    if hasattr(self, "wrapped_act_impl"):
-        output = self.wrapped_act_impl(quant_input)
+    if hasattr(self, "wrappedActImpl"):
+        output = self.wrappedActImpl(quantInput)
     else:
-        output = quant_input
-    quant_output = self.act_quant(output) if self.act_quant is not None else output
-    return quant_output
+        output = quantInput
+    import IPython; IPython.embed()
+    quantOutput = self.act_quant(output) if self.act_quant is not None else output
+    return quantOutput
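
For reference, a minimal sketch (not part of the commit) of how the unrolled activation forward might be attached to a Brevitas layer. The import path and the types.MethodType wiring are assumptions, since this diff view does not show the file name or how the exporter installs the patch:

    import types

    import torch.nn as nn
    from brevitas.nn import QuantReLU

    # Hypothetical import path -- the file name is not shown in this diff view.
    from DeepQuant.CustomForwards.Activation import (
        InnerForwardImplWrapperActivation,
        quantActivationForward,
    )

    act = QuantReLU()

    # Expose the plain activation as a named submodule so FX tracing shows it as its
    # own node, then swap in the unrolled forward for this instance.
    act.wrappedActImpl = InnerForwardImplWrapperActivation(nn.ReLU())
    act.forward = types.MethodType(quantActivationForward, act)

    # Note: as committed, quantActivationForward still contains an
    # `import IPython; IPython.embed()` line, so calling act(...) here would drop
    # into an interactive shell instead of returning immediately.
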
Lines changed: 17 additions & 21 deletions
@@ -4,39 +4,35 @@
 #
 # Federico Brancasi <[email protected]>

-"""
-Custom forward implementations for Brevitas QuantLinear layers.
-"""

-import torch
 import torch.nn as nn
 from torch import Tensor
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer


 class InnerForwardImplWrapperLinear(nn.Module):
     """
-    A small wrapper around the 'inner_forward_impl' of a Brevitas QuantLinear
+    A small wrapper around the 'innerForwardImpl' of a Brevitas QuantLinear
     (QuantWeightBiasInputOutputLayer).

-    We want to expose the logic within 'inner_forward_impl' as a standalone
+    We want to expose the logic within 'innerForwardImpl' as a standalone
     submodule, so that FX tracing can see it as a leaf.
     """

-    def __init__(self, inner_forward_impl: nn.Module) -> None:
+    def __init__(self, innerForwardImpl: nn.Module) -> None:
         """
         Args:
-            inner_forward_impl: The original function that processes
+            innerForwardImpl: The original function that processes
                 (quant_input, quant_weight, quant_bias).
         """
         super().__init__()
-        self.inner_forward_impl = inner_forward_impl
+        self.innerForwardImpl = innerForwardImpl

     def forward(
-        self, quant_input: Tensor, quant_weight: Tensor, quant_bias: Tensor
+        self, quantInput: Tensor, quantWeight: Tensor, quantBias: Tensor
     ) -> Tensor:
         """
-        Applies the wrapped inner_forward_impl.
+        Applies the wrapped innerForwardImpl.

         Args:
             quant_input: Input after input_quant.
@@ -46,18 +42,18 @@ def forward(
         Returns:
             A torch.Tensor with the linear operation applied.
         """
-        return self.inner_forward_impl(quant_input, quant_weight, quant_bias)
+        return self.innerForwardImpl(quantInput, quantWeight, quantBias)


-def quantWBIOL_forward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor:
+def quantWBIOLForward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor:
     """
     Unrolled forward pass for a Brevitas QuantLinear:

     Steps:
     1) self.input_quant
     2) self.weight_quant
     3) self.bias_quant (if bias is present)
-    4) inner_forward_impl (wrapped)
+    4) innerForwardImpl (wrapped)
     5) self.output_quant

     Args:
@@ -67,13 +63,13 @@ def quantWBIOL_forward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor:
     Returns:
         Output Tensor after the unrolled quantized linear steps.
     """
-    quant_input = self.input_quant(inp)
-    quant_weight = self.weight_quant(self.weight)
+    quantInput = self.input_quant(inp)
+    quantWeight = self.weight_quant(self.weight)

-    quant_bias = None
+    quantBias = None
     if self.bias is not None:
-        quant_bias = self.bias_quant(self.bias, quant_input, quant_weight)
+        quantBias = self.bias_quant(self.bias, quantInput, quantWeight)

-    output = self.wrapped_inner_forward_impl(quant_input, quant_weight, quant_bias)
-    quant_output = self.output_quant(output)
-    return quant_output
+    output = self.wrappedInnerForwardImpl(quantInput, quantWeight, quantBias)
+    quantOutput = self.output_quant(output)
+    return quantOutput
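
Similarly, a hedged sketch (not part of the commit) of driving the unrolled linear forward on a stock Brevitas QuantLinear. The import path and the instance-level patching are assumptions:

    import types

    import torch
    from brevitas.nn import QuantLinear

    # Hypothetical import path -- the file name is not shown in this diff view.
    from DeepQuant.CustomForwards.Linear import (
        InnerForwardImplWrapperLinear,
        quantWBIOLForward,
    )

    fc = QuantLinear(16, 8, bias=False)  # bias=False keeps the sketch to steps 1, 2, 4 and 5

    # Wrap the layer's own inner_forward_impl so FX tracing sees the matmul as a leaf,
    # then replace the stock forward with the unrolled one.
    fc.wrappedInnerForwardImpl = InnerForwardImplWrapperLinear(fc.inner_forward_impl)
    fc.forward = types.MethodType(quantWBIOLForward, fc)

    out = fc(torch.randn(2, 16))  # input_quant -> weight_quant -> wrapped matmul -> output_quant
    print(out.shape)  # torch.Size([2, 8])
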

DeepQuant/custom_forwards/multiheadattention.py renamed to DeepQuant/CustomForwards/MultiHeadAttention.py

Lines changed: 29 additions & 32 deletions
@@ -4,9 +4,6 @@
 #
 # Federico Brancasi <[email protected]>

-"""
-Custom forward implementation for Brevitas QuantMultiheadAttention.
-"""

 import math
 import torch
@@ -15,7 +12,7 @@
 from brevitas.nn.quant_mha import QuantMultiheadAttention


-def unrolled_quant_mha_forward(
+def unrolledQuantMhaForward(
     self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor
 ) -> Tensor:
     """
@@ -39,52 +36,52 @@ def unrolled_quant_mha_forward(
         after the unrolled MHA steps.
     """
     # 1) Q, K, V projections
-    q_out = self.q_proj(query)
-    k_out = self.k_proj(key)
-    v_out = self.v_proj(value)
+    qOut = self.q_proj(query)
+    kOut = self.k_proj(key)
+    vOut = self.v_proj(value)

     # 2) Multi-head reshape
-    seq_len, batch_size, embed_dim = q_out.shape
-    head_dim = embed_dim // self.num_heads
+    seqLen, batchSize, embedDim = qOut.shape
+    headDim = embedDim // self.num_heads

-    q_out = (
-        q_out.view(seq_len, batch_size, self.num_heads, head_dim)
+    qOut = (
+        qOut.view(seqLen, batchSize, self.num_heads, headDim)
         .permute(1, 2, 0, 3)
-        .reshape(batch_size * self.num_heads, seq_len, head_dim)
+        .reshape(batchSize * self.num_heads, seqLen, headDim)
     )
-    k_out = (
-        k_out.view(seq_len, batch_size, self.num_heads, head_dim)
+    kOut = (
+        kOut.view(seqLen, batchSize, self.num_heads, headDim)
         .permute(1, 2, 0, 3)
-        .reshape(batch_size * self.num_heads, seq_len, head_dim)
+        .reshape(batchSize * self.num_heads, seqLen, headDim)
     )
-    v_out = (
-        v_out.view(seq_len, batch_size, self.num_heads, head_dim)
+    vOut = (
+        vOut.view(seqLen, batchSize, self.num_heads, headDim)
         .permute(1, 2, 0, 3)
-        .reshape(batch_size * self.num_heads, seq_len, head_dim)
+        .reshape(batchSize * self.num_heads, seqLen, headDim)
     )

     # 3) Scale queries, then quantize
-    q_scaled = q_out / math.sqrt(head_dim)
-    q_scaled = self.q_scaled_quant(q_scaled)
+    qScaled = qOut / math.sqrt(headDim)
+    qScaled = self.q_scaled_quant(qScaled)

     # 4) Transpose + quantize K, compute attention weights
-    k_t = k_out.transpose(-2, -1)
+    k_t = kOut.transpose(-2, -1)
     k_t = self.k_transposed_quant(k_t)

-    attn_weights = torch.bmm(q_scaled, k_t)
-    attn_weights = self.softmax_input_quant(attn_weights)
-    attn_weights = F.softmax(attn_weights, dim=-1)
-    attn_weights = self.attn_output_weights_quant(attn_weights)
+    attnWeights = torch.bmm(qScaled, k_t)
+    attnWeights = self.softmax_input_quant(attnWeights)
+    attnWeights = F.softmax(attnWeights, dim=-1)
+    attnWeights = self.attn_output_weights_quant(attnWeights)

     # 5) Quantize V, multiply, reshape back, and final out projection
-    v_out = self.v_quant(v_out)
-    attn_output = torch.bmm(attn_weights, v_out)
+    vOut = self.v_quant(vOut)
+    attnOutput = torch.bmm(attnWeights, vOut)

-    attn_output = (
-        attn_output.view(batch_size, self.num_heads, seq_len, head_dim)
+    attnOutput = (
+        attnOutput.view(batchSize, self.num_heads, seqLen, headDim)
         .permute(2, 0, 1, 3)
-        .reshape(seq_len, batch_size, embed_dim)
+        .reshape(seqLen, batchSize, embedDim)
     )

-    attn_output = self.out_proj(attn_output)
-    return attn_output
+    attnOutput = self.out_proj(attnOutput)
+    return attnOutput
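
A hedged sketch (not part of the commit) of exercising the unrolled MHA forward directly. It assumes the layer is built with separate Q/K/V projections (packed_in_proj=False) and Brevitas' default sequence-first layout, so inputs are (seq_len, batch, embed_dim):

    import torch
    from brevitas.nn import QuantMultiheadAttention

    from DeepQuant.CustomForwards.MultiHeadAttention import unrolledQuantMhaForward

    # Separate q_proj/k_proj/v_proj submodules are assumed, matching the attributes
    # the unrolled forward reads.
    mha = QuantMultiheadAttention(embed_dim=16, num_heads=4, packed_in_proj=False)

    x = torch.randn(10, 2, 16)  # (seq_len, batch_size, embed_dim)

    # Self-attention: query, key and value are the same tensor.
    out = unrolledQuantMhaForward(mha, x, x, x)
    print(out.shape)  # torch.Size([10, 2, 16])
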
Lines changed: 17 additions & 18 deletions
@@ -9,7 +9,6 @@
 """

 import torch.nn as nn
-import torch.fx as fx
 from brevitas.fx.brevitas_tracer import (
     _symbolic_trace,
     _is_brevitas_leaf_module,
@@ -30,8 +29,8 @@ class CustomBrevitasTracer(Tracer):

     def __init__(
         self,
-        leaf_classes: Optional[List[Type[nn.Module]]] = None,
-        non_leaf_classes: Optional[List[Type[nn.Module]]] = None,
+        leafClasses: Optional[List[Type[nn.Module]]] = None,
+        nonLeafClasses: Optional[List[Type[nn.Module]]] = None,
         debug: bool = False,
     ) -> None:
         """
@@ -43,31 +42,31 @@ def __init__(
             debug: Whether to print debug information during tracing.
         """
         super().__init__()
-        self.leaf_classes = leaf_classes if leaf_classes is not None else []
-        self.non_leaf_classes = non_leaf_classes if non_leaf_classes is not None else []
+        self.leafClasses = leafClasses if leafClasses is not None else []
+        self.nonLeafClasses = nonLeafClasses if nonLeafClasses is not None else []
         self.debug = debug

-    def register_leaf_module(self, module_cls: Type[nn.Module]) -> None:
+    def registerLeafModule(self, moduleCls: Type[nn.Module]) -> None:
         """
         Add a module class to the list of leaf modules.

         Args:
             module_cls: The module class to register as a leaf module.
         """
-        if module_cls not in self.leaf_classes:
-            self.leaf_classes.append(module_cls)
+        if moduleCls not in self.leafClasses:
+            self.leafClasses.append(moduleCls)

-    def register_non_leaf_module(self, module_cls: Type[nn.Module]) -> None:
+    def registerNonLeafModule(self, moduleCls: Type[nn.Module]) -> None:
         """
         Add a module class to the list of non-leaf modules.

         Args:
             module_cls: The module class to register as a non-leaf module.
         """
-        if module_cls not in self.non_leaf_classes:
-            self.non_leaf_classes.append(module_cls)
+        if moduleCls not in self.nonLeafClasses:
+            self.nonLeafClasses.append(moduleCls)

-    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
+    def is_leaf_module(self, m: nn.Module, moduleQualifiedName: str) -> bool:
         """
         Determine whether a module should be treated as a leaf module.

@@ -84,16 +83,16 @@ def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
             bool: True if the module should be treated as a leaf module, False otherwise.
         """
         # First check explicitly registered classes
-        if any(isinstance(m, lc) for lc in self.leaf_classes):
+        if any(isinstance(m, lc) for lc in self.leafClasses):
             return True
-        if any(isinstance(m, nlc) for nlc in self.non_leaf_classes):
+        if any(isinstance(m, nlc) for nlc in self.nonLeafClasses):
             return False
         # Fall back to default Brevitas behavior
-        return _is_brevitas_leaf_module(m, module_qualified_name)
+        return _is_brevitas_leaf_module(m, moduleQualifiedName)


-def custom_brevitas_trace(
-    root: nn.Module, concrete_args=None, tracer: Optional[CustomBrevitasTracer] = None
+def customBrevitasTrace(
+    root: nn.Module, concreteArgs=None, tracer: Optional[CustomBrevitasTracer] = None
 ) -> GraphModule:
     """
     Create an FX GraphModule using the CustomBrevitasTracer.
@@ -108,4 +107,4 @@
     """
     if tracer is None:
         tracer = CustomBrevitasTracer()
-    return _symbolic_trace(tracer, root, concrete_args)
+    return _symbolic_trace(tracer, root, concreteArgs)
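
Finally, a hedged sketch (not part of the commit) of how the renamed tracer API might be used. The import path is a guess, since the file name is not shown in this view:

    import torch.nn as nn
    from brevitas.nn import QuantLinear, QuantReLU

    # Hypothetical import path -- the file name is not shown in this diff view.
    from DeepQuant.CustomTracer import CustomBrevitasTracer, customBrevitasTrace

    model = nn.Sequential(QuantLinear(16, 8, bias=False), QuantReLU())

    tracer = CustomBrevitasTracer(debug=True)
    # Keep QuantReLU instances opaque in the traced graph; everything else falls back
    # to the default Brevitas leaf rules via _is_brevitas_leaf_module.
    tracer.registerLeafModule(QuantReLU)

    gm = customBrevitasTrace(model, tracer=tracer)
    print(gm.graph)
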
