Commit 41ff413 (1 parent: de10852)

Add kv cache args in get_example_inputs_kvcache_sdpa

1 file changed: examples/models/llama/model.py (13 additions, 12 deletions)
```diff
@@ -300,7 +300,7 @@ def get_example_inputs(self):
     # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
     def get_example_inputs_kvcache_sdpa(self):
         if self.enable_dynamic_shape:
-            return (
+            args = (
                 torch.tensor(
                     [[0 for _ in range(self.static_seq_length)]], dtype=torch.long
                 ),
@@ -315,18 +315,19 @@ def get_example_inputs_kvcache_sdpa(self):
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output are we on.
             )
-        if self.decode_kv_cache_as_io:
-            args = args + (
-                # (n_layers, max_batch_size, n_heads, max_seq_length, head_dim)
-                torch.zeros(self._cache_shape, dtype=torch.float16),  # k-cache
-                torch.zeros(self._cache_shape, dtype=torch.float16),  # v-cache
-            )
+
+            if self.decode_kv_cache_as_io:
+                args = args + (
+                    # (n_layers, max_batch_size, n_heads, max_seq_length, head_dim)
+                    torch.zeros(self._cache_shape, dtype=torch.float16),  # k-cache
+                    torch.zeros(self._cache_shape, dtype=torch.float16),  # v-cache
+                )
 
-        if self.use_additive_kv_cache_update:
-            args = args + (
-                torch.zeros(self._cache_pos_mask_shape, dtype=torch.float16),
-            )
-        return args
+            if self.use_additive_kv_cache_update:
+                args = args + (
+                    torch.zeros(self._cache_pos_mask_shape, dtype=torch.float16),
+                )
+            return args
 
     def _transform_for_pre_quantization(self, checkpoint, model_args):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
```
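Read as a whole, the patched method now builds the example-input tuple once (`args = (` instead of an early `return (`) and then conditionally appends the cache tensors before returning. Below is a minimal, self-contained sketch of that control flow. The module-level flags and shapes stand in for the `self.*` attributes seen in the diff; the concrete values of `cache_shape` and `cache_pos_mask_shape` are placeholder assumptions, and the tensor elements elided between the two hunks are omitted.

```python
import torch

# Stand-ins for the self.* attributes referenced in the diff; real values
# come from the model config. Shapes here are illustrative placeholders.
enable_dynamic_shape = True
static_seq_length = 1
decode_kv_cache_as_io = True
use_additive_kv_cache_update = False
cache_shape = (2, 1, 4, 16, 8)  # (n_layers, max_batch_size, n_heads, max_seq_length, head_dim)
cache_pos_mask_shape = (1, 16)  # placeholder; the real shape is model-defined

def get_example_inputs_kvcache_sdpa():
    if enable_dynamic_shape:
        args = (
            torch.tensor([[0] * static_seq_length], dtype=torch.long),  # tokens
            torch.tensor([0], dtype=torch.long),  # start_pos, what token of output we are on
        )
        if decode_kv_cache_as_io:
            # The commit's key change: the k/v caches are appended to the
            # example inputs so they travel through export as explicit IO.
            args = args + (
                torch.zeros(cache_shape, dtype=torch.float16),  # k-cache
                torch.zeros(cache_shape, dtype=torch.float16),  # v-cache
            )
        if use_additive_kv_cache_update:
            # One extra mask tensor for the additive cache-position update.
            args = args + (
                torch.zeros(cache_pos_mask_shape, dtype=torch.float16),
            )
        return args

example_inputs = get_example_inputs_kvcache_sdpa()
print([tuple(t.shape) for t in example_inputs])
```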

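Example inputs like these exist to drive tracing: the tuple's arity and dtypes must match the model's `forward` signature, which is why the k/v cache tensors have to be appended whenever the caches are treated as inputs/outputs. A toy sketch with `torch.export` illustrating this; the `ToyDecodeStep` module is a hypothetical stand-in, not the llama model.

```python
import torch

class ToyDecodeStep(torch.nn.Module):
    # Hypothetical stand-in for the decode step: it consumes the same four
    # positional inputs the tuple provides when decode_kv_cache_as_io is set.
    def forward(self, tokens, start_pos, k_cache, v_cache):
        return tokens.float().sum() + start_pos.float() + k_cache.sum() + v_cache.sum()

cache_shape = (2, 1, 4, 16, 8)  # (n_layers, max_batch_size, n_heads, max_seq_length, head_dim)
example_inputs = (
    torch.tensor([[0]], dtype=torch.long),          # tokens
    torch.tensor([0], dtype=torch.long),            # start_pos
    torch.zeros(cache_shape, dtype=torch.float16),  # k-cache
    torch.zeros(cache_shape, dtype=torch.float16),  # v-cache
)

# torch.export traces forward with exactly these positional inputs, so the
# tuple's arity must match the signature -- the point of appending cache args.
exported = torch.export.export(ToyDecodeStep(), example_inputs)
print(exported.graph_signature)
```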