44#
55# Federico Brancasi <[email protected] > 66
7-
87import pytest
98import torch
109import torch .nn as nn
2120
2221
2322class QuantMHSANet (nn .Module ):
23+ """Simple quantized network with multi-head self-attention."""
2424
25- def __init__ (self , embed_dim : int , num_heads : int ) -> None :
26- """
27- Args:
28- embed_dim: The dimension of each embedding vector.
29- num_heads: The number of attention heads.
30- """
25+ def __init__ (self , embedDim : int , numHeads : int ) -> None :
3126 super ().__init__ ()
3227 self .inputQuant = qnn .QuantIdentity (return_quant_tensor = True )
3328 self .mha = qnn .QuantMultiheadAttention (
34- embed_dim = embed_dim ,
35- num_heads = num_heads ,
29+ embed_dim = embedDim ,
30+ num_heads = numHeads ,
3631 dropout = 0.0 ,
3732 bias = True ,
38- packed_in_proj = False , # separate Q, K, V
39- batch_first = False , # expects (sequence, batch, embed_dim)
33+ packed_in_proj = False , # FBRANCASI: separate Q, K, V
34+ batch_first = False , # FBRANCASI: expects (sequence, batch, embed_dim)
4035 in_proj_input_quant = Int8ActPerTensorFloat ,
4136 in_proj_weight_quant = Int8WeightPerTensorFloat ,
4237 in_proj_bias_quant = Int32Bias ,
@@ -51,27 +46,14 @@ def __init__(self, embed_dim: int, num_heads: int) -> None:
5146 )
5247
5348 def forward (self , x : Tensor ) -> Tensor :
54- """
55- Forward pass that first quantizes the input, then applies multi-head attention.
56-
57- Args:
58- x: Input tensor of shape [sequence_len, batch_size, embed_dim].
59-
60- Returns:
61- A tuple (output, None) as per the Brevitas MHA API, where output has shape
62- [sequence_len, batch_size, embed_dim].
63- """
6449 x = self .inputQuant (x )
6550 out = self .mha (x , x , x )
6651 return out
6752
6853
6954@pytest .mark .SingleLayerTests
7055def deepQuantTestMHSA () -> None :
71-
7256 torch .manual_seed (42 )
73-
74- model = QuantMHSANet (embed_dim = 16 , num_heads = 4 ).eval ()
57+ model = QuantMHSANet (embedDim = 16 , numHeads = 4 ).eval ()
7558 sampleInput = torch .randn (10 , 2 , 16 )
76-
7759 exportQuantModel (model , sampleInput )
0 commit comments