# Federico Brancasi <[email protected]>

import math
+from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from brevitas.nn.quant_mha import QuantMultiheadAttention
from torch import Tensor


-def mhaForward(
-    self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor
+def _mhaForwardImpl(
+    self: QuantMultiheadAttention,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    need_transpose_in: bool,
+    need_transpose_out: bool,
) -> Tensor:
-    """Explicit, export-friendly MHA forward."""
-    qOut = self.q_proj(query)
-    kOut = self.k_proj(key)
-    vOut = self.v_proj(value)
+    """Core MHA forward implementation."""
+    # FBRANCASI: Handle batch_first by transposing if needed
+    if need_transpose_in:
+        if key is value:
+            if query is key:
+                query = key = value = query.transpose(1, 0)
+            else:
+                query, key = [x.transpose(1, 0) for x in (query, key)]
+                value = key
+        else:
+            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
+
+    if self.in_proj is not None:
+        # FBRANCASI: Handle packed projections (default case for models like ViT)
+        # Only self-attention where query is key is value is supported
+        if not (query is key and key is value):
+            raise RuntimeError(
+                "Packed in_proj is supported only for self-attention where query is key is value. Set packed_in_proj=False."
+            )
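+        # The packed projection yields (seqLen, batchSize, 3 * embedDim); unwrap a
+        # possible Brevitas QuantTensor via .value, then split into q, k, v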
+        qkv = self.in_proj(query)
+        qkv_tensor = qkv.value if hasattr(qkv, "value") else qkv
+        qOut, kOut, vOut = qkv_tensor.chunk(3, dim=-1)
+    else:
+        q_result = self.q_proj(query)
+        k_result = self.k_proj(key)
+        v_result = self.v_proj(value)
+
+        qOut = q_result.value if hasattr(q_result, "value") else q_result
+        kOut = k_result.value if hasattr(k_result, "value") else k_result
+        vOut = v_result.value if hasattr(v_result, "value") else v_result

    seqLen, batchSize, embedDim = qOut.shape
    headDim = embedDim // self.num_heads

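+    # Split heads for batched matmul:
+    # (seqLen, batchSize, embedDim) -> (batchSize * num_heads, seqLen, headDim)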
    qOut = (
-        qOut.view(seqLen, batchSize, self.num_heads, headDim)
-        .permute(1, 2, 0, 3)
-        .reshape(batchSize * self.num_heads, seqLen, headDim)
+        qOut.contiguous()
+        .view(seqLen, batchSize * self.num_heads, headDim)
+        .transpose(0, 1)
    )
    kOut = (
-        kOut.view(seqLen, batchSize, self.num_heads, headDim)
-        .permute(1, 2, 0, 3)
-        .reshape(batchSize * self.num_heads, seqLen, headDim)
+        kOut.contiguous()
+        .view(seqLen, batchSize * self.num_heads, headDim)
+        .transpose(0, 1)
    )
    vOut = (
-        vOut.view(seqLen, batchSize, self.num_heads, headDim)
-        .permute(1, 2, 0, 3)
-        .reshape(batchSize * self.num_heads, seqLen, headDim)
+        vOut.contiguous()
+        .view(seqLen, batchSize * self.num_heads, headDim)
+        .transpose(0, 1)
    )

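+    # Scale queries by 1 / sqrt(headDim) before computing attention scores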
    qScaled = qOut / math.sqrt(headDim)
@@ -54,10 +86,65 @@ def mhaForward(
    attnOutput = torch.bmm(attnWeights, vOut)

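+    # Merge heads back: (batchSize * num_heads, seqLen, headDim) -> (seqLen, batchSize, embedDim)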
    attnOutput = (
-        attnOutput.view(batchSize, self.num_heads, seqLen, headDim)
-        .permute(2, 0, 1, 3)
-        .reshape(seqLen, batchSize, embedDim)
+        attnOutput.transpose(0, 1).contiguous().view(seqLen, batchSize, embedDim)
    )

-    attnOutput = self.out_proj(attnOutput)
+    out_result = self.out_proj(attnOutput)
+    attnOutput = out_result.value if hasattr(out_result, "value") else out_result
+
+    if need_transpose_out:
+        attnOutput = attnOutput.transpose(1, 0)
+
    return attnOutput
+
+
+def mhaForwardBatchFirst(
+    self: QuantMultiheadAttention,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    need_weights: bool = True,
+    **kwargs,
+) -> Tuple[Tensor, Optional[Tensor]]:
+    """MHA forward for batch_first=True."""
+    attn_output = _mhaForwardImpl(
+        self, query, key, value, need_transpose_in=True, need_transpose_out=True
+    )
+    # PyTorch always returns a tuple, even when need_weights=False
+    return (attn_output, None)
+
+
+def mhaForwardSeqFirst(
+    self: QuantMultiheadAttention,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    need_weights: bool = True,
+    **kwargs,
+) -> Tuple[Tensor, Optional[Tensor]]:
+    """MHA forward for batch_first=False."""
+    attn_output = _mhaForwardImpl(
+        self, query, key, value, need_transpose_in=False, need_transpose_out=False
+    )
+    # PyTorch always returns a tuple, even when need_weights=False
+    return (attn_output, None)
+
+
+def mhaForward(
+    self: QuantMultiheadAttention,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    need_weights: bool = True,
+    **kwargs,
+) -> Tuple[Tensor, Optional[Tensor]]:
+    """Explicit, export-friendly MHA forward.
+
+    This function will be replaced with the appropriate batch_first or seq_first
+    version during module transformation, based on the module's batch_first attribute.
+    """
+    # FBRANCASI: Dispatch to the appropriate version before tracing
+    if self.batch_first:
+        return mhaForwardBatchFirst(self, query, key, value, need_weights, **kwargs)
+    else:
+        return mhaForwardSeqFirst(self, query, key, value, need_weights, **kwargs)
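
Example usage (an illustrative sketch, not part of this commit; patchQuantMha and model are assumed names): bind the resolved forward to each QuantMultiheadAttention before tracing, so the batch_first decision is made ahead of export.

import types

def patchQuantMha(model):
    # Pick the forward matching each module's batch_first setting and bind it
    # as that instance's forward, so no data-dependent branching is traced.
    for module in model.modules():
        if isinstance(module, QuantMultiheadAttention):
            fwd = mhaForwardBatchFirst if module.batch_first else mhaForwardSeqFirst
            module.forward = types.MethodType(fwd, module)
    return model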