@@ -12,7 +12,7 @@
 from typing import List, Optional, Tuple
 
 import torch
-from executorch.examples.models.llama2.llama_transformer import Attention
+from executorch.examples.models.llama.llama_transformer import Attention
 from torch import nn
 
 
@@ -28,7 +28,7 @@ def apply_rotary_emb_single(
     return x_out
 
 
-class KVCacheSha(torch.nn.Module):
+class KVCacheSHA(torch.nn.Module):
     def __init__(
         self,
         max_batch_size: int,
@@ -74,7 +74,7 @@ def get_cache(self, head_idx):
         )
 
 
-class SDPASha(torch.nn.Module):
+class SDPASHA(torch.nn.Module):
 
     def __init__(
         self,
@@ -89,7 +89,7 @@ def __init__(
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.dim = dim
-        self.kv_cache = KVCacheSha(
+        self.kv_cache = KVCacheSHA(
             max_batch_size, max_seq_length, n_heads // n_rep, head_dim
         )
         self.scale_factor = math.sqrt(head_dim)
@@ -123,7 +123,7 @@ def forward(
         return torch.cat(output, dim=-1)
 
 
-class AttentionSha(nn.Module):
+class AttentionSHA(nn.Module):
     def __init__(self, attention_mha: nn.Module):
         super().__init__()
         if not attention_mha.use_kv_cache:
@@ -136,7 +136,7 @@ def __init__(self, attention_mha: nn.Module):
         self.max_batch_size = attention_mha.max_batch_size
         self.max_seq_len = attention_mha.max_seq_len
         self.head_dim = attention_mha.dim // self.n_heads
-        self.SDPA = SDPASha(
+        self.SDPA = SDPASHA(
             self.max_batch_size,
             self.max_seq_len,
             self.n_heads,
@@ -212,7 +212,7 @@ def replace_attention_to_attention_sha(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                AttentionSha(child),
+                AttentionSHA(child),
             )
         else:
             replace_attention_to_attention_sha(child)
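For context, a minimal call-site sketch of the function touched in the last hunk (assumptions: the surrounding module exposes replace_attention_to_attention_sha as shown above, and `model` is any nn.Module containing llama_transformer.Attention submodules; the loader helper below is illustrative and not part of this diff):

    # Hypothetical usage: recursively swap every multi-head Attention module
    # for the split single-head AttentionSHA wrapper before lowering/export.
    model = load_llama_model()  # assumed helper, not defined in this diff
    replace_attention_to_attention_sha(model)  # mutates the module tree in place via setattr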