Skip to content
Merged
21 changes: 21 additions & 0 deletions examples/models/llama2/source_transformation/quantized_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,27 @@ def update(self, input_pos, k_val, v_val):
self.quantized_cache_dtype,
self.cache_fp_type,
)

if self.is_transposed:
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
dim_to_slice = 2 if self.is_transposed else 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k.copy_(k_val)
# pyre-ignore: Incompatible parameter type [6]
narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v.copy_(v_val)
else:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
else:
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_quantized_cache(k_val, k_out, start_pos)
_ = torch.ops.llama.update_quantized_cache(v_val, v_out, start_pos)

return k_out, v_out

@classmethod
Expand Down
Loading