
Commit 2f00a67

Add New Models and Fix MHA Problems
1 parent 21b9c74 commit 2f00a67

File tree

8 files changed: +541, -24 lines

DeepQuant/CustomForwards/MultiHeadAttention.py

Lines changed: 106 additions & 19 deletions
@@ -5,38 +5,70 @@
 # Federico Brancasi <[email protected]>
 
 import math
+from typing import Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 from brevitas.nn.quant_mha import QuantMultiheadAttention
 from torch import Tensor
 
 
-def mhaForward(
-    self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor
+def _mhaForwardImpl(
+    self: QuantMultiheadAttention,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    need_transpose_in: bool,
+    need_transpose_out: bool,
 ) -> Tensor:
-    """Explicit, export-friendly MHA forward."""
-    qOut = self.q_proj(query)
-    kOut = self.k_proj(key)
-    vOut = self.v_proj(value)
+    """Core MHA forward implementation."""
+    # FBRANCASI: Handle batch_first by transposing if needed
+    if need_transpose_in:
+        if key is value:
+            if query is key:
+                query = key = value = query.transpose(1, 0)
+            else:
+                query, key = [x.transpose(1, 0) for x in (query, key)]
+                value = key
+        else:
+            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
+
+    if self.in_proj is not None:
+        # FBRANCASI: Handle packed projections (default case for models like ViT)
+        # Only support self-attention where query == key == value
+        if not (query is key and key is value):
+            raise RuntimeError(
+                "Packed in_proj is supported only for self-attention with k is v is q. Set packed_in_proj=False."
+            )
+        qkv = self.in_proj(query)
+        qkv_tensor = qkv.value if hasattr(qkv, "value") else qkv
+        qOut, kOut, vOut = qkv_tensor.chunk(3, dim=-1)
+    else:
+        q_result = self.q_proj(query)
+        k_result = self.k_proj(key)
+        v_result = self.v_proj(value)
+
+        qOut = q_result.value if hasattr(q_result, "value") else q_result
+        kOut = k_result.value if hasattr(k_result, "value") else k_result
+        vOut = v_result.value if hasattr(v_result, "value") else v_result
 
     seqLen, batchSize, embedDim = qOut.shape
     headDim = embedDim // self.num_heads
 
     qOut = (
-        qOut.view(seqLen, batchSize, self.num_heads, headDim)
-        .permute(1, 2, 0, 3)
-        .reshape(batchSize * self.num_heads, seqLen, headDim)
+        qOut.contiguous()
+        .view(seqLen, batchSize * self.num_heads, headDim)
+        .transpose(0, 1)
     )
     kOut = (
-        kOut.view(seqLen, batchSize, self.num_heads, headDim)
-        .permute(1, 2, 0, 3)
-        .reshape(batchSize * self.num_heads, seqLen, headDim)
+        kOut.contiguous()
+        .view(seqLen, batchSize * self.num_heads, headDim)
+        .transpose(0, 1)
     )
     vOut = (
-        vOut.view(seqLen, batchSize, self.num_heads, headDim)
-        .permute(1, 2, 0, 3)
-        .reshape(batchSize * self.num_heads, seqLen, headDim)
+        vOut.contiguous()
+        .view(seqLen, batchSize * self.num_heads, headDim)
+        .transpose(0, 1)
    )
 
     qScaled = qOut / math.sqrt(headDim)
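
The reshaping change in the hunk above swaps the view / permute / reshape chain for a single view plus transpose, the layout PyTorch's own multi-head attention uses internally. Both formulations produce the same (batch * heads, seq, headDim) tensor; a minimal standalone sketch (names and shapes are illustrative, not code from this repository) that checks the equivalence:

# Standalone check (assumed shapes): old permute-based split vs. new view/transpose split.
import torch

seqLen, batchSize, numHeads, headDim = 7, 3, 4, 16
embedDim = numHeads * headDim
x = torch.randn(seqLen, batchSize, embedDim)

old = (
    x.view(seqLen, batchSize, numHeads, headDim)
    .permute(1, 2, 0, 3)
    .reshape(batchSize * numHeads, seqLen, headDim)
)
new = (
    x.contiguous()
    .view(seqLen, batchSize * numHeads, headDim)
    .transpose(0, 1)
)
assert torch.equal(old, new)  # identical (batch * heads, seq, headDim) layout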
@@ -54,10 +86,65 @@ def mhaForward
     attnOutput = torch.bmm(attnWeights, vOut)
 
     attnOutput = (
-        attnOutput.view(batchSize, self.num_heads, seqLen, headDim)
-        .permute(2, 0, 1, 3)
-        .reshape(seqLen, batchSize, embedDim)
+        attnOutput.transpose(0, 1).contiguous().view(seqLen, batchSize, embedDim)
     )
 
-    attnOutput = self.out_proj(attnOutput)
+    out_result = self.out_proj(attnOutput)
+    attnOutput = out_result.value if hasattr(out_result, "value") else out_result
+
+    if need_transpose_out:
+        attnOutput = attnOutput.transpose(1, 0)
+
     return attnOutput
+
+
+def mhaForwardBatchFirst(
+    self: QuantMultiheadAttention,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    need_weights: bool = True,
+    **kwargs,
+) -> Tuple[Tensor, Optional[Tensor]]:
+    """MHA forward for batch_first=True."""
+    attn_output = _mhaForwardImpl(
+        self, query, key, value, need_transpose_in=True, need_transpose_out=True
+    )
+    # PyTorch always returns a tuple, even when need_weights=False
+    return (attn_output, None)
+
+
+def mhaForwardSeqFirst(
+    self: QuantMultiheadAttention,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    need_weights: bool = True,
+    **kwargs,
+) -> Tuple[Tensor, Optional[Tensor]]:
+    """MHA forward for batch_first=False."""
+    attn_output = _mhaForwardImpl(
+        self, query, key, value, need_transpose_in=False, need_transpose_out=False
+    )
+    # PyTorch always returns a tuple, even when need_weights=False
+    return (attn_output, None)
+
+
+def mhaForward(
+    self: QuantMultiheadAttention,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    need_weights: bool = True,
+    **kwargs,
+) -> Tuple[Tensor, Optional[Tensor]]:
+    """Explicit, export-friendly MHA forward.
+
+    This function is replaced with the appropriate batch_first or seq_first
+    version during module transformation, based on the module's batch_first
+    attribute.
+    """
+    # FBRANCASI: Dispatch to the appropriate version before tracing
+    if self.batch_first:
+        return mhaForwardBatchFirst(self, query, key, value, need_weights, **kwargs)
+    else:
+        return mhaForwardSeqFirst(self, query, key, value, need_weights, **kwargs)
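
The new wrappers keep one core implementation and differ only in whether inputs and outputs are transposed, returning the (output, attention_weights) tuple that nn.MultiheadAttention callers expect. A toy sketch of the same pattern (the attention body is stubbed out with a linear layer; none of these names come from the repository):

import torch
import torch.nn as nn

def _impl(self, x, need_transpose_in, need_transpose_out):
    if need_transpose_in:            # (batch, seq, embed) -> (seq, batch, embed)
        x = x.transpose(1, 0)
    y = self.proj(x)                 # stand-in for the attention computation
    if need_transpose_out:           # back to (batch, seq, embed)
        y = y.transpose(1, 0)
    return y

def forwardBatchFirst(self, x, **kwargs):
    return (_impl(self, x, True, True), None)    # tuple, like nn.MultiheadAttention

def forwardSeqFirst(self, x, **kwargs):
    return (_impl(self, x, False, False), None)

toy = nn.Module()
toy.proj = nn.Linear(8, 8)
toy.batch_first = True
toy.forward = (forwardBatchFirst if toy.batch_first else forwardSeqFirst).__get__(toy)

out, attn_weights = toy(torch.randn(2, 5, 8))    # out: (2, 5, 8), attn_weights: None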

DeepQuant/Pipeline/DequantUnify.py

Lines changed: 4 additions & 1 deletion
@@ -80,7 +80,10 @@ def mergeDequants(
 
     # FBRANCASI: Check output equivalence with a warning instead of error
     if checkEquivalence:
-        if not torch.allclose(referenceOutput, output, atol=1e-5) and debug:
+        # FBRANCASI: Handle case where output/referenceOutput might be tuples
+        refToCompare = referenceOutput[0] if isinstance(referenceOutput, tuple) else referenceOutput
+        outToCompare = output[0] if isinstance(output, tuple) else output
+        if not torch.allclose(refToCompare, outToCompare, atol=1e-5) and debug:
             print(
                 cc.warning(
                     "Modification of Dequant Nodes may have changed the output slightly"

DeepQuant/Pipeline/Injection.py

Lines changed: 3 additions & 1 deletion
@@ -51,7 +51,9 @@ def injectCustomForwards(
     output = fxModel(exampleInput)
 
     if checkEquivalence:
-        if torch.allclose(referenceOutput, output, atol=1e-5):
+        # Handle case where output might be a tuple (e.g., from MHA)
+        outputToCompare = output[0] if isinstance(output, tuple) else output
+        if torch.allclose(referenceOutput, outputToCompare, atol=1e-5):
             if debug:
                 print(cc.success("Injection of New Modules: output is consistent"))
         else:

DeepQuant/Pipeline/QuantSplit.py

Lines changed: 4 additions & 1 deletion
@@ -46,7 +46,10 @@ def splitQuantNodes(
     output = splitModel(exampleInput)
 
     if checkEquivalence:
-        if torch.allclose(referenceOutput, output, atol=1e-5):
+        # FBRANCASI: Handle case where output/referenceOutput might be tuples
+        refToCompare = referenceOutput[0] if isinstance(referenceOutput, tuple) else referenceOutput
+        outToCompare = output[0] if isinstance(output, tuple) else output
+        if torch.allclose(refToCompare, outToCompare, atol=1e-5):
             if debug:
                 print(cc.success("Split of Quant Nodes: output is consistent"))
         else:

DeepQuant/QuantManipulation/QuantNodesDivider.py

Lines changed: 12 additions & 0 deletions
@@ -110,6 +110,18 @@ def convertQuantOperations(
                 newCatArgs[0] = updatedTensors
                 userNode.args = tuple(newCatArgs)
                 usersUpdated = True
+            elif (
+                userNode.op == "call_function"
+                and userNode.target == getattr
+                and len(userNode.args) >= 2
+                and userNode.args[0] is node
+                and userNode.args[1] == "value"
+            ):
+                # FBRANCASI: Special handling for .value access on dequant output
+                # Replace getattr(dequant_node, 'value') with just dequant_node
+                userNode.replace_all_uses_with(dequantNode)
+                nodesToRemove.append(userNode)
+                usersUpdated = True
             else:
                 # FBRANCASI: Standard node reference replacement
                 newArgs = []
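
The new elif branch handles FX graphs in which the traced code accessed .value on a quantized output: symbolic tracing turns that attribute access into a call_function node targeting getattr, and the branch reroutes its users to the dequant node and schedules it for removal. A standalone illustration of the same rewrite on a plain torch.fx graph (our example, not the repository's code):

import torch
import torch.fx as fx

class WithValueAccess(torch.nn.Module):
    def forward(self, x):
        # Under symbolic tracing this becomes: call_function getattr(x, "value")
        return x.value + 1

gm = fx.symbolic_trace(WithValueAccess())

for node in list(gm.graph.nodes):
    if (
        node.op == "call_function"
        and node.target == getattr
        and len(node.args) >= 2
        and node.args[1] == "value"
    ):
        node.replace_all_uses_with(node.args[0])  # reroute users to the producer
        gm.graph.erase_node(node)

gm.recompile()
print(gm(torch.tensor([1.0, 2.0])))  # plain tensors now work: tensor([2., 3.])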

DeepQuant/Transforms/Transformations.py

Lines changed: 9 additions & 2 deletions
@@ -14,8 +14,11 @@
 from brevitas.nn.quant_mha import QuantMultiheadAttention
 
 from DeepQuant.CustomForwards.Activations import WrapperActivation, activationForward
+from DeepQuant.CustomForwards.MultiHeadAttention import (
+    mhaForwardBatchFirst,
+    mhaForwardSeqFirst,
+)
 from DeepQuant.CustomForwards.WBIOL import WBIOLForward, WrapperWBIOL
-from DeepQuant.CustomForwards.MultiHeadAttention import mhaForward
 from DeepQuant.Transforms.Base import TransformationPass
 from DeepQuant.Utils.CustomTracer import QuantTracer
 
@@ -82,7 +85,11 @@ def injectForward(
         self, module: nn.Module, tracer: Optional[QuantTracer] = None
     ) -> None:
        """Inject custom forward for multi-head attention layers."""
-        module.forward = mhaForward.__get__(module)
+        # Select the appropriate forward function based on batch_first
+        if module.batch_first:
+            module.forward = mhaForwardBatchFirst.__get__(module)
+        else:
+            module.forward = mhaForwardSeqFirst.__get__(module)
 
         if tracer:
             tracer.registerNonLeafModule(QuantMultiheadAttention)
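
Side note on the .__get__(module) idiom used above: plain functions are descriptors, so func.__get__(module) yields a bound method with self already supplied, which is what lets the injected forward reach self.in_proj, self.out_proj, and the other module attributes. A minimal sketch (customForward and the toy module are hypothetical):

import torch
import torch.nn as nn

def customForward(self, x):
    return self.linear(x) * 2         # uses `self` like a regular method

m = nn.Module()
m.linear = nn.Linear(4, 4)

m.forward = customForward.__get__(m)  # bound: m(x) passes the input as `x`
out = m(torch.randn(1, 4))

# m.forward = customForward           # unbound: m(x) would put the input in
#                                     # `self` and fail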

0 commit comments
