@@ -1230,11 +1230,14 @@ def verify_kv_cache(torch_present):
             else:
                 attention_packed_mask = None
             if attention_type == 'gpt2_attention':
-                torch_output, torch_present = attention(
-                    input_tensor,
-                    layer_past=None,
-                    use_cache=True,
-                    attention_mask=attention_mask)
+                # gpt2 uses DynamicCache
+                torch_present = DynamicCache.from_legacy_cache(
+                    torch_present)
+                torch_output = attention(input_tensor,
+                                         past_key_value=torch_present,
+                                         use_cache=True,
+                                         attention_mask=attention_mask)[0]
+                torch_present = torch_present.to_legacy_cache()
             elif attention_type == 'llama_attention':
                 position_embeddings = rotary_emb(input_tensor, position_ids)
                 attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
@@ -1277,7 +1280,7 @@ def verify_kv_cache(torch_present):
 
             torch.cuda.synchronize()
 
-            if attention_type == 'llama_attention':
+            if attention_type in ['llama_attention', 'gpt2_attention']:
                 kv_dequant_scale, kv_quant_scale = get_kv_quant_scale(
                     torch_present[0])
             else:
@@ -1322,7 +1325,7 @@ def verify_kv_cache(torch_present):
                 torch_output[:, :in_len // 2, :].to(
                     torch.float32).cpu().numpy(),
                 atol=5e-3)
-            if attention_type == 'llama_attention':
+            if attention_type in ['llama_attention', 'gpt2_attention']:
                 verify_kv_cache(torch_present[0])
             else:
                 verify_kv_cache(torch_present)
@@ -1374,11 +1377,14 @@ def verify_kv_cache(torch_present):
 
             # torch execution
             if attention_type == 'gpt2_attention':
-                torch_output, torch_present = attention(
-                    input_tensor,
-                    layer_past=torch_present,
-                    use_cache=True,
-                    attention_mask=attention_mask)
+                # gpt2 uses DynamicCache
+                torch_present = DynamicCache.from_legacy_cache(
+                    torch_present)
+                torch_output = attention(input_tensor,
+                                         past_key_value=torch_present,
+                                         use_cache=True,
+                                         attention_mask=attention_mask)[0]
+                torch_present = torch_present.to_legacy_cache()
             elif attention_type == 'llama_attention':
                 position_embeddings = rotary_emb(input_tensor, position_ids)
                 attention_mask = attention_mask + AttentionMaskConverter._make_causal_mask(
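
For context, a minimal sketch (not part of this patch) of the legacy-tuple/DynamicCache round-trip that the new gpt2_attention path relies on. The single-layer cache and tensor shapes below are made up for illustration; it assumes a transformers version that ships DynamicCache with from_legacy_cache/to_legacy_cache.

# Illustrative only -- not from the patched test.
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: a tuple over layers of (key, value) tensors,
# each shaped [batch, num_heads, seq_len, head_dim] (sizes made up here).
legacy = ((torch.zeros(1, 12, 4, 64), torch.zeros(1, 12, 4, 64)),)

cache = DynamicCache.from_legacy_cache(legacy)  # wrap the tuple (None also works)
# ...the attention module is called with past_key_value=cache and updates it in place;
# the call returns the attention output first, hence the [0] in the diff above...
legacy_again = cache.to_legacy_cache()          # back to the tuple-of-(key, value) form
assert torch.equal(legacy_again[0][0], legacy[0][0])

Converting back with to_legacy_cache() lets the rest of the test keep indexing torch_present[0] as a plain (key, value) pair, which is why the llama-only branches above now accept gpt2_attention as well.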