@@ -209,6 +209,8 @@ def pattern(
209209 # that share key/value.
210210
211211 key_seq_BHkvTDh = op .Concat (past_key , key_BHkvSDh_rope , axis = - 2 )
212+ # Concat with past_key is optional:
213+ key_seq_BHkvTDh = pattern .OrValue ([key_seq_BHkvTDh , key_BHkvSDh_rope ])
212214 key_seq_BHkv1TDh = op .Unsqueeze (key_seq_BHkvTDh , 2 )
213215 key_seq_BHkvGTDh = op .Expand (key_seq_BHkv1TDh , pattern .ANY_VALUE )
214216 key_seq_BHTDh = op .Reshape (
@@ -218,6 +220,8 @@ def pattern(
218220 # Concatenate past_value cache and current value, expand across heads
219221 # that share key/value.
220222 value_seq_BHkvTDh = op .Concat (past_value , value_BHkvSDh , axis = - 2 )
223+ # Concat with past_value is optional:
224+ value_seq_BHkvTDh = pattern .OrValue ([value_seq_BHkvTDh , value_BHkvSDh ])
221225 value_seq_BHkv1TDh = op .Unsqueeze (value_seq_BHkvTDh , 2 )
222226 value_seq_BHkvGTDh = op .Expand (value_seq_BHkv1TDh , pattern .ANY_VALUE )
223227 value_seq_BHTDh = op .Reshape (
@@ -268,9 +272,9 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
268272 if no_match (value_BSDkv , ["B" , "S" , "Dkv" ]):
269273 return False
270274
271- if no_match (past_key , ["B" , "Hkv" , "P" , "Dh" ]):
275+ if past_key is not None and no_match (past_key , ["B" , "Hkv" , "P" , "Dh" ]):
272276 return False
273- if no_match (past_value , ["B" , "Hkv" , "P" , "Dv" ]):
277+ if past_value is not None and no_match (past_value , ["B" , "Hkv" , "P" , "Dv" ]):
274278 return False
275279
276280 # TODO: verify Reshapes:
0 commit comments