@@ -82,13 +82,11 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
 class SDPASimple(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -103,11 +101,6 @@ def forward(
         seqlen,
         mask,
     ):
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
         attn_mask = mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
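The rest of `SDPASimple.forward` is unchanged and not shown in this hunk. For orientation, here is a self-contained sketch of the computation the module performs after this change, assuming the caller has already transposed q/k/v and applied the KV cache; the softmax/matmul tail is standard SDPA math, not lines from this commit:

```python
import math

import torch


def simple_sdpa(q, k, v, mask, input_pos, n_rep, head_dim):
    # q: (bs, n_heads, seqlen, head_dim), already transposed by the caller
    # k, v: (bs, n_local_heads, cache_len, head_dim), already cache-updated
    attn_mask = mask[None, None, input_pos]  # (1, 1, seqlen, cache_len)
    k = k.repeat_interleave(n_rep, dim=1)  # expand grouped KV heads to n_heads
    v = v.repeat_interleave(n_rep, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim) + attn_mask
    return torch.softmax(scores, dim=-1) @ v


bs, n_local_heads, n_rep, head_dim, cache_len, seqlen = 1, 2, 2, 8, 16, 3
q = torch.randn(bs, n_local_heads * n_rep, seqlen, head_dim)
k = torch.randn(bs, n_local_heads, cache_len, head_dim)
v = torch.randn(bs, n_local_heads, cache_len, head_dim)
mask = torch.zeros(cache_len, cache_len)  # additive mask; all-zero = attend everywhere
out = simple_sdpa(q, k, v, mask, torch.arange(seqlen), n_rep, head_dim)
print(out.shape)  # torch.Size([1, 4, 3, 8])
```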
@@ -141,12 +134,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class SDPAFlex(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.n_rep = n_rep
 
@@ -160,9 +151,10 @@ def forward(
         seqlen,
         mask,
     ):
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
+        """
+        q: (bs, n_heads, seqlen, head_dim)
+        k, v: (bs, n_local_heads, seqlen, head_dim)
+        """
         k = repeat_kv(k, self.n_rep)
         v = repeat_kv(v, self.n_rep)
         attn_mask = mask[input_pos]
@@ -182,7 +174,7 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+                SDPASimple(child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_simple_sdpa(child)
@@ -195,7 +187,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPAFlex(child.kv_cache, child.dim, child.n_rep),
+                SDPAFlex(child.dim, child.n_rep),
             )
         else:
             replace_sdpa_with_flex_sdpa(child)
@@ -227,13 +219,11 @@ class SDPACoreML(torch.nn.Module):
227219
228220 def __init__ (
229221 self ,
230- kv_cache : KVCache ,
231222 dim : int ,
232223 head_dim : int ,
233224 n_rep : int ,
234225 ):
235226 super ().__init__ ()
236- self .kv_cache = kv_cache
237227 self .dim = dim
238228 self .head_dim = head_dim
239229 self .n_rep = n_rep
@@ -248,11 +238,6 @@ def forward(
         seqlen,
         mask,
     ):
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
         attn_mask = mask[None, None, input_pos]
 
         if self.n_rep > 1:
@@ -270,7 +255,7 @@ def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+                SDPACoreML(child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_coreml_sdpa(child)
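All three `replace_sdpa_with_*` passes share the same recursive swap pattern, and after this change the replacement module is built without the child's `kv_cache`. A minimal, runnable illustration of that pattern, using hypothetical `OldBlock`/`NewBlock` stand-ins rather than the real SDPA classes:

```python
import torch


class OldBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dim, self.head_dim, self.n_rep = 64, 16, 2


class NewBlock(torch.nn.Module):
    def __init__(self, dim, head_dim, n_rep):
        super().__init__()
        self.dim, self.head_dim, self.n_rep = dim, head_dim, n_rep


def replace_blocks(module: torch.nn.Module):
    # Walk the module tree; swap matching children in place via setattr,
    # rebuilding the replacement from the child's attributes.
    for name, child in module.named_children():
        if isinstance(child, OldBlock):
            setattr(module, name, NewBlock(child.dim, child.head_dim, child.n_rep))
        else:
            replace_blocks(child)


model = torch.nn.Sequential(torch.nn.Sequential(OldBlock()))
replace_blocks(model)
print(model)  # the inner OldBlock is now a NewBlock
```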
@@ -357,6 +342,9 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # can we combine this with KVCacheCoreML?
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val)
         v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val)
 
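Taken together, the diff moves both the layout transpose and the cache update out of the SDPA modules: k_val/v_val now reach `update` already in the (bs, n_heads, seqlen, head_dim) layout the SDPA modules use, and the cache transposes them itself before the `index_put_`. A sketch of a compatible cache, under the assumption that the buffers are laid out as (bs, max_seq, n_heads, head_dim); the buffer names match the diff, but the constructor and return are assumptions:

```python
from typing import Tuple

import torch


class KVCacheSketch(torch.nn.Module):
    # Hypothetical stand-in for the KVCache touched in this diff, assuming
    # cache buffers laid out as (bs, max_seq, n_heads, head_dim).
    def __init__(self, max_batch: int, max_seq: int, n_heads: int, head_dim: int):
        super().__init__()
        self.register_buffer(
            "past_k_caches", torch.zeros(max_batch, max_seq, n_heads, head_dim)
        )
        self.register_buffer(
            "past_v_caches", torch.zeros(max_batch, max_seq, n_heads, head_dim)
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # k_val/v_val arrive as (bs, n_heads, seqlen, head_dim); the cache,
        # not the SDPA module, now owns the transpose back to cache layout.
        k_val = k_val.transpose(1, 2)
        v_val = v_val.transpose(1, 2)
        # Scatter the new entries into the seq dimension at input_pos.
        k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val)
        v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val)
        return k_out, v_out


cache = KVCacheSketch(max_batch=1, max_seq=16, n_heads=4, head_dim=8)
k = torch.randn(1, 4, 3, 8)  # (bs, n_heads, seqlen=3, head_dim)
v = torch.randn(1, 4, 3, 8)
k_out, v_out = cache.update(torch.tensor([0, 1, 2]), k, v)
print(k_out.shape)  # torch.Size([1, 16, 4, 8])
```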