Merged
examples/models/llama/source_transformation/custom_kv_cache.py (42 additions, 21 deletions)
@@ -6,7 +6,7 @@
 
 import logging
 from enum import Enum
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -93,7 +93,7 @@ def _quantize(self, value):
         )
         return quantized_value, scales, zero_points
 
-    def _quantize_and_update(self, input_pos, k_val, v_val):
+    def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
 
@@ -104,26 +104,37 @@ def _quantize_and_update(self, input_pos, k_val, v_val):
 
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
             _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos
+                quantized_k_val, self.k_cache, start_pos, indices
             )
-            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
             _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos
+                k_scales, self.k_cache_scales, start_pos, indices
             )
+            _ = torch.ops.llama.update_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache(
+                quantized_v_val, self.v_cache, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache(
+                v_scales, self.v_cache_scales, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos, indices
+            )
         else:
+            assert indices is None, "Indices not supported for this path"
+            # Following is also broken because in prefill input_pos = [0]
+            # but we need to update some slice of cache
             self.k_cache[:, input_pos] = quantized_k_val
             self.k_cache_scales[:, input_pos] = k_scales
             self.k_cache_zero_points[:, input_pos] = k_zero_points
             self.v_cache[:, input_pos] = quantized_v_val
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points
 
-    def _update_and_return_float_values(self, input_pos, k_val, v_val):
-        self._quantize_and_update(input_pos, k_val, v_val)
+    def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None):
+        self._quantize_and_update(input_pos, k_val, v_val, indices)
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
@@ -144,24 +155,26 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val):
             self.cache_fp_type,
         )
 
-        # When returning float values we jsut use the last value
+        # When returning float values we just use the last value
         # instead of dequantized value.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos, indices)
+            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos, indices)
         else:
            k_out[:, input_pos] = k_val
            v_out[:, input_pos] = v_val
 
         return k_out, v_out
 
-    def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
-        self._quantize_and_update(input_pos, k_val, v_val)
+    def _update_and_return_quantized_values(
+        self, input_pos, k_val, v_val, indices=None
+    ):
+        self._quantize_and_update(input_pos, k_val, v_val, indices)
 
         return self.k_cache, self.v_cache
 
-    def update(self, input_pos, k_val, v_val):
+    def update(self, input_pos, k_val, v_val, indices=None):
         """
         k_val, v_val: [B, H, S, D]
         return: [B, H, S, D]
@@ -172,10 +185,12 @@ def update(self, input_pos, k_val, v_val):
         v_val = v_val.transpose(1, 2)
 
         if self.return_float_values:
-            k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
+            k_out, v_out = self._update_and_return_float_values(
+                input_pos, k_val, v_val, indices
+            )
         else:
             k_out, v_out = self._update_and_return_quantized_values(
-                input_pos, k_val, v_val
+                input_pos, k_val, v_val, indices
             )
         return k_out.transpose(1, 2), v_out.transpose(1, 2)
 
@@ -277,14 +292,20 @@ def __init__(
         )
 
     def update(
-        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+        self,
+        input_pos: torch.Tensor,
+        k_val: torch.Tensor,
+        v_val: torch.Tensor,
+        indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
         k_val = k_val.transpose(1, 2)
         v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
-        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
-        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+
+        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, indices)
+        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, indices)
+
         return (
             self.k_cache.transpose(1, 2),
             self.v_cache.transpose(1, 2),
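Taken together, the custom_kv_cache.py changes thread an optional `indices` tensor from `QuantizedKVCache.update()` through `_quantize_and_update()` into every `torch.ops.llama.update_cache` call, presumably so that new tokens can be written to arbitrary cache slots rather than only a contiguous run starting at `start_pos`. The snippet below is a minimal plain-PyTorch emulation of that behavior under those assumed semantics (indices as per-token destination slots, shape `[B, S]`, `int64`, matching the validation added in custom_ops.py); `update_cache_reference` is a hypothetical helper, not code from this PR.

```python
# Illustrative emulation only: mimics what an indices-aware cache update could
# do in plain PyTorch. The real torch.ops.llama.update_cache is a C++ custom op.
import torch


def update_cache_reference(value, cache, start_pos, indices=None):
    # value: [B, S, H, D] (the transposed layout this cache stores),
    # cache: [B, MAX_S, H, D], indices: [B, S] int64 or None.
    batch, seq_len = value.shape[0], value.shape[1]
    if indices is None:
        # Contiguous write: tokens land at start_pos .. start_pos + seq_len.
        cache[:, start_pos : start_pos + seq_len] = value
    else:
        # Scattered write: indices[b, s] names the cache slot for token s of batch b.
        assert indices.shape == (batch, seq_len) and indices.dtype == torch.int64
        for b in range(batch):
            cache[b, indices[b]] = value[b]
    return cache


# Example: place two new tokens of batch 0 at non-contiguous slots 5 and 9.
cache = torch.zeros(1, 16, 8, 64)
value = torch.randn(1, 2, 8, 64)
update_cache_reference(value, cache, start_pos=0, indices=torch.tensor([[5, 9]]))
```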
extension/llm/custom_ops/custom_ops.py (27 additions, 11 deletions)
@@ -184,6 +184,7 @@ def _validate_update_cache_params(
     value,
     cache,
     start_pos,
+    indices=None,
 ):
     seq_len = value.size(1)
     assert (
@@ -200,29 +201,44 @@ def _validate_update_cache_params(
         ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
 
     torch._check_is_size(start_pos)
-    # Setting to arbitrary limit of 256 for now since there is no way
-    # to plumb this information from model config
-    torch._check(start_pos < cache.size(1))
-    assert start_pos < cache.size(
-        1
-    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
-
-    torch._check((start_pos + seq_len) < cache.size(1))
-    assert (start_pos + seq_len) < cache.size(
-        1
-    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+    if indices is None:
+        torch._check(start_pos < cache.size(1))
+        assert start_pos < cache.size(
+            1
+        ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+
+        torch._check((start_pos + seq_len) < cache.size(1))
+        assert (start_pos + seq_len) < cache.size(
+            1
+        ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+
+    if indices is not None:
+        assert (
+            indices.dim() == 2
+        ), f"Expected indices to be 2 dimensional but got {indices.dim()} dimensions."
+        assert (
+            indices.dtype == torch.int64
+        ), f"Expected indices to be int64 but got {indices.dtype}"
+        assert indices.size(0) == value.size(
+            0
+        ), f"Expected indices batch dimension to match value batch dimension but got {indices.size(0)} and {value.size(0)}"
+        assert indices.size(1) == value.size(
+            1
+        ), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(1)}"
 
 
 @impl(custom_ops_lib, "update_cache", "Meta")
 def update_cache_meta(
     value,
     cache,
     start_pos,
+    indices=None,
 ):
     _validate_update_cache_params(
         value,
         cache,
         start_pos,
+        indices,
     )
 
     # Update cache doesnt really return anything but I dont know a better
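The new checks spell out the contract for `indices`: a 2-D `int64` tensor whose batch and sequence dimensions match `value`, with the `start_pos` bounds checks skipped whenever `indices` is supplied. A small sketch of arguments that would satisfy this validation (names and sizes are illustrative; actually dispatching `torch.ops.llama.update_cache` still requires the ExecuTorch custom-op library to be loaded):

```python
# Shapes that satisfy the checks added to _validate_update_cache_params.
# Illustrative only; the variable names are not from the repo.
import torch

B, S, H, D, MAX_S = 1, 4, 8, 64, 128
value = torch.randn(B, S, H, D)
cache = torch.zeros(B, MAX_S, H, D)

# Without indices, start_pos is bounds-checked against the cache length.
start_pos = 10
assert start_pos + S < cache.size(1)

# With indices: a [B, S] int64 tensor whose batch and sequence dims match
# value; the start_pos checks above are then skipped.
indices = torch.tensor([[3, 7, 11, 42]], dtype=torch.int64)
assert indices.dim() == 2
assert indices.dtype == torch.int64
assert indices.size(0) == value.size(0) and indices.size(1) == value.size(1)
```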
extension/llm/custom_ops/op_sdpa_aot.cpp (12 additions, 8 deletions)
@@ -122,12 +122,14 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
+    const std::optional<Tensor> indices,
     Tensor& output);
 
 at::Tensor update_cache_aten(
     const at::Tensor& value,
     at::Tensor& cache,
-    const int64_t start_pos);
+    const int64_t start_pos,
+    const std::optional<at::Tensor>& indices);
 
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
@@ -324,19 +326,21 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
+    const std::optional<Tensor> indices,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::update_cache_out(
-      context, value, cache, start_pos, output);
+      context, value, cache, start_pos, indices, output);
 }
 
 at::Tensor update_cache_aten(
     const at::Tensor& value,
     at::Tensor& cache,
-    const int64_t start_pos) {
+    const int64_t start_pos,
+    const std::optional<at::Tensor>& indices) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_cache_out_no_context, 3)
-  (value, cache, start_pos, output);
+  WRAP_TO_ATEN(update_cache_out_no_context, 4)
+  (value, cache, start_pos, indices, output);
   return output;
 }
 
@@ -363,10 +367,10 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "update_cache(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos) -> Tensor");
+      "SymInt start_pos, Tensor? indices=None) -> Tensor");
   m.def(
       "update_cache.out(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
+      "SymInt start_pos, Tensor? indices=None, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
@@ -396,7 +400,7 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl("update_cache", torch::executor::native::update_cache_aten);
   m.impl(
       "update_cache.out",
-      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4));
   m.impl(
       "custom_quantized_sdpa",
       torch::executor::native::custom_quantized_sdpa_aten);
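Since both schemas declare the new argument as `Tensor? indices=None`, existing call sites continue to work unchanged and new call sites can pass an explicit index tensor. A hedged usage sketch, assuming the ExecuTorch custom-op shared library has been built and loaded (the path below is a placeholder, not a real artifact name):

```python
# Usage sketch based on the schemas above; the library path is a placeholder
# for the built ExecuTorch custom-op shared object.
import torch

torch.ops.load_library("/path/to/libcustom_ops_aot_lib.so")  # placeholder path

value = torch.randn(1, 4, 8, 64)
cache = torch.zeros(1, 128, 8, 64)

# Existing call sites keep working: indices defaults to None.
torch.ops.llama.update_cache(value, cache, 0)

# New call sites can pass explicit destination slots for each token.
indices = torch.tensor([[3, 7, 11, 42]], dtype=torch.int64)
torch.ops.llama.update_cache(value, cache, 0, indices)
```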