 import torch
-from torch import nn
-
-from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.attention import ForwardOptions
 from executorch.examples.models.llama.feed_forward import FeedForward
 
+from executorch.examples.models.llama.norm import RMSNorm
+from torch import nn
+
 
 class ShortConv(nn.Module):
     def __init__(
@@ -61,10 +61,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # So, assuming prefill is done on an empty cache, concatenating conv_state to the beginning of the sequence acts similarly to
         ## using nn.Conv1d(padding=L_cache-1) (for prefill) without any manual padding.
         ## However, the manual padding has the added benefit of being correct during decode, when the cache is not initialized to 0.
-        Bx = torch.cat([self.conv_state, Bx], dim=-1)  # (batch_size, dim, seq_len + L_cache - 1)
+        Bx = torch.cat(
+            [self.conv_state, Bx], dim=-1
+        )  # (batch_size, dim, seq_len + L_cache - 1)
 
         ## Update the conv_state
-        new_conv_state = Bx[..., -(self.conv.weight.size(-1) - 1) :]  # (batch_size, dim, L_cache - 1)
+        new_conv_state = Bx[
+            ..., -(self.L_cache - 1) :
+        ]  # (batch_size, dim, L_cache - 1)
         with torch.no_grad():
             self.conv_state.copy_(new_conv_state)
 
@@ -83,15 +87,20 @@ def reset_cache(self):
 class ShortConvBlock(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, norm_eps: float):
         super().__init__()
-        # hardcode 3 for now
-        L_cache = 3
-        self.conv = ShortConv(dim, L_cache, bias=False)
+        self.L_cache = 3  # hardcode 3 for now
+        self.conv = ShortConv(dim, self.L_cache, bias=False)
         self.feed_forward = FeedForward(dim, hidden_dim)
         self.ffn_norm = RMSNorm(dim, norm_eps)
         # use attention_norm instead of operator_norm to unify with TransformerBlock
         self.attention_norm = RMSNorm(dim, norm_eps)
 
-    def forward(self, x, freqs_cos=None, freqs_sin=None, _unused_attn_options: ForwardOptions = None):  # x: 1xN
+    def forward(
+        self,
+        x,
+        freqs_cos=None,
+        freqs_sin=None,
+        _unused_attn_options: ForwardOptions = None,
+    ):  # x: 1xN
         h = self.conv.forward(self.attention_norm(x))
         h = x + h
         out = h + self.feed_forward(self.ffn_norm(h))
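
For reference (not part of the diff): a minimal standalone sketch of the equivalence the comments in the forward hunk describe. The dim, L_cache, and seq_len values below are arbitrary, and the depthwise nn.Conv1d is only an assumption about how ShortConv's internal convolution is configured.

# Sketch: prepending a zero conv_state of length L_cache - 1 to an unpadded
# depthwise Conv1d reproduces the prefill output of Conv1d(padding=L_cache - 1).
import torch
from torch import nn

dim, L_cache, seq_len = 4, 3, 7  # arbitrary example sizes
x = torch.randn(1, dim, seq_len)

conv = nn.Conv1d(dim, dim, L_cache, groups=dim, bias=False)

# Padded variant: pad, then keep only the first seq_len (causal) outputs.
padded = nn.functional.conv1d(x, conv.weight, padding=L_cache - 1, groups=dim)[
    ..., :seq_len
]

# Cache variant: prepend a zero-initialized conv_state instead of padding.
conv_state = torch.zeros(1, dim, L_cache - 1)
cached = conv(torch.cat([conv_state, x], dim=-1))

assert torch.allclose(padded, cached, atol=1e-6)
# During decode, conv_state holds the last L_cache - 1 real inputs instead of zeros,
# which is why the explicit concatenation stays correct while static padding would not.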