@@ -300,7 +300,7 @@ def get_example_inputs(self):
300300 # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
301301 def get_example_inputs_kvcache_sdpa (self ):
302302 if self .enable_dynamic_shape :
303- return (
303+ args = (
304304 torch .tensor (
305305 [[0 for _ in range (self .static_seq_length )]], dtype = torch .long
306306 ),
@@ -315,18 +315,19 @@ def get_example_inputs_kvcache_sdpa(self):
315315 [0 ], dtype = torch .long
316316 ), # start_pos, what token of output are we on.
317317 )
318- if self .decode_kv_cache_as_io :
319- args = args + (
320- # (n_layers, max_batch_size, n_heads, max_seq_length, head_dim)
321- torch .zeros (self ._cache_shape , dtype = torch .float16 ), # k-cache
322- torch .zeros (self ._cache_shape , dtype = torch .float16 ), # v-cache
323- )
318+
319+ if self .decode_kv_cache_as_io :
320+ args = args + (
321+ # (n_layers, max_batch_size, n_heads, max_seq_length, head_dim)
322+ torch .zeros (self ._cache_shape , dtype = torch .float16 ), # k-cache
323+ torch .zeros (self ._cache_shape , dtype = torch .float16 ), # v-cache
324+ )
324325
325- if self .use_additive_kv_cache_update :
326- args = args + (
327- torch .zeros (self ._cache_pos_mask_shape , dtype = torch .float16 ),
328- )
329- return args
326+ if self .use_additive_kv_cache_update :
327+ args = args + (
328+ torch .zeros (self ._cache_pos_mask_shape , dtype = torch .float16 ),
329+ )
330+ return args
330331
331332 def _transform_for_pre_quantization (self , checkpoint , model_args ):
332333 assert hasattr (self .args , "preq_mode" ), "preq_mode must be specified"
0 commit comments