Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def __init__(
f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
)

# This reduces dequantize overhead by dequantizing only up to the
# current input position.
self.return_sliced_cache = True
# For now supporting int8 only
self.use_custom_update_cache_op = use_custom_update_cache_op
self.quantized_cache_dtype = torch.int8
Expand Down Expand Up @@ -125,19 +128,34 @@ def update(self, input_pos, k_val, v_val):
self.v_cache_scales[:, input_pos] = v_scales
self.v_cache_zero_points[:, input_pos] = v_zero_points

if self.return_sliced_cache:
k_cache_out = self.k_cache[:, :input_pos+1]
k_cache_scales = self.k_cache_scales[:, :input_pos+1]
k_cache_zero_points = self.k_cache_zero_points[:, :input_pos+1]
v_cache_out = self.v_cache[:, :input_pos+1]
v_cache_scales = self.v_cache_scales[:, :input_pos+1]
v_cache_zero_points = self.v_cache_zero_points[:, :input_pos+1]
else:
k_cache_out = self.k_cache
k_cache_scales = self.k_cache_scales
k_cache_zero_points = self.k_cache_zero_points
v_cache_out = self.v_cache
v_cache_scales = self.v_cache_scales
v_cache_zero_points = self.v_cache_zero_points

k_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.k_cache,
self.k_cache_scales,
self.k_cache_zero_points,
k_cache_out,
k_cache_scales,
k_cache_zero_points,
torch.iinfo(self.quantized_cache_dtype).min,
torch.iinfo(self.quantized_cache_dtype).max,
self.quantized_cache_dtype,
self.cache_fp_type,
)
v_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.v_cache,
self.v_cache_scales,
self.v_cache_zero_points,
v_cache_out,
v_cache_scales,
v_cache_zero_points,
torch.iinfo(self.quantized_cache_dtype).min,
torch.iinfo(self.quantized_cache_dtype).max,
self.quantized_cache_dtype,
Expand Down
Loading