@@ -12,8 +12,9 @@
 from typing import Tuple, Union

 import torch
+import torch.nn.functional as F

-from executorch.examples.models.llama.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama.llama_transformer import KVCache, SDPA, FeedForward
 from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
     QuantizedKVCache,
 )
@@ -171,12 +172,14 @@ def __init__(
         self,
         kv_cache: KVCache,
         dim: int,
+        head_dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
         self.dim = dim
         self.n_rep = n_rep
+        self.scale_factor = math.sqrt(head_dim)

     def forward(
         self,
@@ -195,8 +198,7 @@ def forward(
         v = repeat_kv(v, self.n_rep)
         attn_mask = mask[input_pos]

-        scale_factor = 1 / math.sqrt(q.size(-1))
-        attn_weight = q @ k.transpose(-2, -1) * scale_factor
+        attn_weight = q @ k.transpose(-2, -1) / self.scale_factor
         attn_weight += attn_mask
         attn_weight = torch.softmax(attn_weight, dim=-1)
         y = attn_weight @ v
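
Note on the hunk above: dividing by a scale factor precomputed from head_dim in __init__ is equivalent to the removed per-call multiplication by 1 / math.sqrt(q.size(-1)) whenever head_dim matches the last dimension of q, while avoiding recomputing the constant on every forward call. A minimal standalone check of that equivalence, using made-up tensor shapes rather than the real SDPAFlex inputs:

import math

import torch

head_dim = 64
q = torch.randn(1, 8, 16, head_dim)  # (bsz, n_heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, head_dim)

# removed form: multiply by the reciprocal of sqrt(q.size(-1)), recomputed each call
old = q @ k.transpose(-2, -1) * (1 / math.sqrt(q.size(-1)))
# new form: divide by a scale factor computed once from head_dim
new = q @ k.transpose(-2, -1) / math.sqrt(head_dim)

assert torch.allclose(old, new)
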
@@ -223,7 +225,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPAFlex(child.kv_cache, child.dim, child.n_rep),
+                SDPAFlex(child.kv_cache, child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_flex_sdpa(child)
@@ -428,3 +430,50 @@ def replace_causal_mask(module: torch.nn.Module):
     for _, child in module.named_children():
         replace_causal_mask(child)
     return module
+
+class FeedForwardConv2D(torch.nn.Module):
+    def __init__(self, w1: torch.nn.Linear, w2: torch.nn.Linear, w3: torch.nn.Linear):
+        super().__init__()
+        self.w1_conv = torch.nn.Conv2d(
+            in_channels=w1.weight.shape[1],
+            out_channels=w1.weight.shape[0],
+            kernel_size=1,
+            padding=0,
+            bias=False,
+        )
+        self.w2_conv = torch.nn.Conv2d(
+            in_channels=w2.weight.shape[1],
+            out_channels=w2.weight.shape[0],
+            kernel_size=1,
+            padding=0,
+            bias=False,
+        )
+        self.w3_conv = torch.nn.Conv2d(
+            in_channels=w3.weight.shape[1],
+            out_channels=w3.weight.shape[0],
+            kernel_size=1,
+            padding=0,
+            bias=False,
+        )
+
+        self.w1_conv.weight = torch.nn.Parameter(w1.weight.reshape(*w1.weight.shape, 1, 1))
+        self.w2_conv.weight = torch.nn.Parameter(w2.weight.reshape(*w2.weight.shape, 1, 1))
+        self.w3_conv.weight = torch.nn.Parameter(w3.weight.reshape(*w3.weight.shape, 1, 1))
+
+
+    def forward(self, x):
+        rank = x.dim()
+        x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1)
+        x = torch.transpose(x, 1, 2)
+        res = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x))
+        res = torch.transpose(res, 1, 2)
+        res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3])
+        return res
+
+def replace_feedforward_to_conv2d(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, FeedForward):
+            setattr(module, name, FeedForwardConv2D(child.w1, child.w2, child.w3))
+        else:
+            replace_feedforward_to_conv2d(child)
+    return module
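
FeedForwardConv2D reimplements the Llama SiLU-gated feed-forward, w2(silu(w1(x)) * w3(x)), with 1x1 Conv2d layers that reuse the original Linear weights, so the numerics are unchanged while the graph is expressed as convolutions; replace_feedforward_to_conv2d then swaps every FeedForward instance in a module tree, e.g. replace_feedforward_to_conv2d(model). A rough sanity check of the equivalence, assuming FeedForwardConv2D from this file is in scope; the toy dimensions below are illustrative, not taken from this commit:

import torch
import torch.nn.functional as F

dim, hidden_dim = 32, 96  # toy sizes for the check only
w1 = torch.nn.Linear(dim, hidden_dim, bias=False)  # gate projection
w2 = torch.nn.Linear(hidden_dim, dim, bias=False)  # down projection
w3 = torch.nn.Linear(dim, hidden_dim, bias=False)  # up projection

ffn_conv = FeedForwardConv2D(w1, w2, w3)

x = torch.randn(2, 10, dim)      # (bsz, seq_len, dim)
ref = w2(F.silu(w1(x)) * w3(x))  # gated FFN with the original Linears
out = ffn_conv(x)                # same computation via 1x1 convolutions

assert out.shape == ref.shape
assert torch.allclose(out, ref, atol=1e-5)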