Commit 20307f2
Add update_quantized_cache op
Pull Request resolved: #5527

Why?
- Functionalization introduces a ton of copies.
- Without such custom in-place ops, mutable-buffer support results in giant copies at the end of the graph.
- Making true in-place ops work will likely take longer, and there is no clear safe path to it.

ghstack-source-id: 245150346
@exported-using-ghexport

Differential Revision: [D62301838](https://our.internmc.facebook.com/intern/diff/D62301838/)
1 parent eab2e2d commit 20307f2
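To make the motivation concrete, here is a minimal sketch (module name and shapes are made up for illustration) of how torch.export functionalizes an in-place slice write into a mutable buffer: the mutation is rewritten into out-of-place ops such as slice_scatter, and the updated buffer must then be written back in full, which is the "giant copies" the commit message refers to. A custom op with a mutable-alias schema like update_quantized_cache sidesteps that rewrite.

import torch
from torch.export import export

class CacheWrite(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(1, 16, 2, 4))

    def forward(self, value):
        # An eager in-place write into the buffer along the sequence dim.
        self.cache[:, 0 : value.size(1)] = value
        return self.cache.sum()

ep = export(CacheWrite(), (torch.ones(1, 3, 2, 4),))
# Functionalization replaces the in-place write with out-of-place ops
# (e.g. slice_scatter); the mutated buffer is then copied back in full.
print(ep.graph_module.code)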

File tree

7 files changed: +480 -2 lines changed


extension/llm/custom_ops/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -22,6 +22,19 @@ runtime.python_test(
     ],
 )
 
+runtime.python_test(
+    name = "test_update_quantized_cache",
+    srcs = [
+        "test_update_quantized_cache.py",
+    ],
+    preload_deps = [
+        ":custom_ops_aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "test_preprocess_custom_ops",
     srcs = [

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 39 additions & 2 deletions
@@ -9,14 +9,14 @@
 #include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
 #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
 #include <executorch/extension/llm/custom_ops/op_sdpa.h>
+#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>
 
 #include <torch/library.h>
 
 namespace torch {
 namespace executor {
 
 namespace native {
-namespace {
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
@@ -81,7 +81,27 @@ at::Tensor sdpa_with_kv_cache_aten(
       output);
   return output;
 }
-} // namespace
+
+Tensor& update_quantized_cache_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::native::update_quantized_cache_out(
+      context, value, cache, start_pos, output);
+}
+
+at::Tensor update_quantized_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos) {
+  auto output = at::empty({1});
+  WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
+  (value, cache, start_pos, output);
+  return output;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -95,6 +115,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
       "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+  m.def(
+      "update_quantized_cache(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos) -> Tensor");
+  m.def(
+      "update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
 }
 
 TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
@@ -105,3 +131,14 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
       WRAP_TO_ATEN(
           torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
 }
+
+// TODO: Rename this file to op_custom_ops_aot.cpp
+TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl(
+      "update_quantized_cache",
+      torch::executor::native::update_quantized_cache_aten);
+  m.impl(
+      "update_quantized_cache.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::update_quantized_cache_out_no_context, 3));
+}
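With these registrations in place, the op becomes callable from eager PyTorch once the AOT library is loaded. A minimal sketch follows; the library path and the tensor shapes are illustrative assumptions, not part of this diff (in the tests, :custom_ops_aot_lib is preloaded instead):

import torch

# Hypothetical path; adjust to wherever custom_ops_aot_lib is built.
torch.ops.load_library("libcustom_ops_aot_lib.so")

cache = torch.zeros(1, 256, 8, 64, dtype=torch.int8)  # [B, max_seq_len, n_heads, head_dim]
value = torch.randint(-128, 127, (1, 4, 8, 64), dtype=torch.int8)

# Mutates `cache` in place per the Tensor(a!) alias annotation;
# the returned tensor is just a 1-element placeholder.
torch.ops.llama.update_quantized_cache(value, cache, 10)
assert torch.equal(cache[:, 10:14], value)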
extension/llm/custom_ops/op_update_quantized_cache.cpp

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>
+
+#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+// @lint-ignore CLANGTIDY facebook-unused-include-check
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+
+#include <array>
+
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+namespace {
+bool validate_cache_params(
+    const Tensor& quantized_value,
+    const Tensor& quantized_cache,
+    int64_t start_pos,
+    int64_t seq_length) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      start_pos < quantized_cache.size(1),
+      "start_pos must be less than cache size at dim 1");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      (start_pos + seq_length) <= quantized_cache.size(1),
+      "start_pos + seq_length must be less than max seq length supported by cache. "
+      "start pos: %" PRId64 ", seq_length: %" PRId64 ". "
+      "cache size: %zd",
+      start_pos,
+      seq_length,
+      quantized_cache.size(1));
+
+  // Make sure they are in contiguous dim order
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      is_contiguous_dim_order(
+          quantized_cache.dim_order().data(), quantized_cache.dim()),
+      "quantized cache must be in contiguous dim order");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      is_contiguous_dim_order(
+          quantized_value.dim_order().data(), quantized_value.dim()),
+      "quantized value must be in contiguous dim order");
+
+  return true;
+}
+} // anonymous namespace
+
+Tensor& update_quantized_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  (void)ctx;
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len),
+      InvalidArgument,
+      output);
+
+  ET_CHECK_MSG(
+      value.size(0) == cache.size(0),
+      "projected_value batch size should be equal to the cache batch size.");
+  ET_CHECK_MSG(
+      value.size(2) == cache.size(2),
+      "projected_value number of heads should be equal to the cache number of heads.");
+  ET_CHECK_MSG(
+      value.size(3) == cache.size(3),
+      "projected_value embedding dimension should be equal to the cache embedding dimension.");
+  ET_CHECK_MSG(
+      value.element_size() == cache.element_size(),
+      "projected_value data type size should be equal to the cache data type size.");
+
+  ET_CHECK_MSG(
+      is_contiguous_dim_order(value.dim_order().data(), value.dim()),
+      "projected value must be in contiguous dim order");
+  ET_CHECK_MSG(
+      is_contiguous_dim_order(cache.dim_order().data(), cache.dim()),
+      "cache must be in contiguous dim order");
+
+  const void* value_data = value.const_data_ptr();
+  void* cache_data = cache.mutable_data_ptr();
+
+  ET_CHECK_MSG(value_data, "projected_value data is null");
+  ET_CHECK_MSG(cache_data, "cache data is null");
+
+  auto cache_strides = cache.strides();
+  exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
+  exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];
+
+  auto value_strides = value.strides();
+  exec_aten::StridesType value_batch_dim_stride = value_strides[0];
+
+  exec_aten::SizesType num_bytes_to_copy =
+      (value.numel() / value.size(0)) * value.element_size();
+
+  for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) {
+    exec_aten::SizesType cache_pos_offset =
+        (batch_line * cache_batch_dim_stride +
+         start_pos * cache_seq_dim_stride) *
+        cache.element_size();
+    exec_aten::SizesType value_pos_offset =
+        (batch_line * value_batch_dim_stride) * cache.element_size();
+
+    std::memcpy(
+        (uint8_t*)cache_data + cache_pos_offset,
+        (uint8_t*)value_data + value_pos_offset,
+        num_bytes_to_copy);
+  }
+
+  // No one uses output. Just a placeholder.
+  return output;
+}
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+// Really this is just an in-place tensor update op
+// which makes assumptions about the rank of the tensor
+// and the dim order (memory layout) of the tensor.
+// It further assumes that the indexing is along the
+// sequence dimension (dim 1) of the tensor.
+// In later diffs this will be renamed to update_cache.
+EXECUTORCH_LIBRARY(
+    llama,
+    "update_quantized_cache.out",
+    torch::executor::native::update_quantized_cache_out);
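Semantically, the strided-memcpy loop above is just a slice assignment along dim 1: since both tensors are verified to be in contiguous dim order, each batch line's [seq_len, n_heads, head_dim] block is one contiguous run of num_bytes_to_copy bytes. A reference sketch of the same semantics in PyTorch (shapes are assumptions, not part of the diff):

import torch

def update_cache_reference(value, cache, start_pos):
    # value: [B, seq_len, n_heads, head_dim]; cache: [B, max_seq_len, n_heads, head_dim]
    # Both contiguous, so this slice write copies, per batch line, the same
    # seq_len * n_heads * head_dim * element_size bytes as the C++ memcpy loop.
    seq_len = value.size(1)
    cache[:, start_pos : start_pos + seq_len] = value
    return cache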
extension/llm/custom_ops/op_update_quantized_cache.h

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+Tensor& update_quantized_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output);
+} // namespace native
+} // namespace executor
+} // namespace torch

extension/llm/custom_ops/sdpa_with_kv_cache.py

Lines changed: 52 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 from torch.library import impl
 
+# TODO: rename this file to custom_ops_meta_registration.py
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -138,3 +139,54 @@ def fast_hadamard_transform_meta(mat):
     # assert(mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!")
     # assert(mat.is_contiguous(), "input matrix must be contiguous currently!")
     return torch.empty_like(mat)
+
+
+def _validate_update_cache_params(
+    value,
+    cache,
+    start_pos,
+):
+    seq_len = value.size(1)
+    assert (
+        value.dim() == 4
+    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."
+
+    assert (
+        value.dtype == cache.dtype
+    ), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}"
+
+    for i in [0, 2, 3]:
+        assert value.size(i) == cache.size(
+            i
+        ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
+
+    torch._check_is_size(start_pos)
+    # Setting to arbitrary limit of 256 for now since there is no way
+    # to plumb this information from model config
+    torch._check(start_pos < cache.size(1))
+    assert start_pos < cache.size(
+        1
+    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+
+    torch._check((start_pos + seq_len) < cache.size(1))
+    assert (start_pos + seq_len) < cache.size(
+        1
+    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+
+
+@impl(custom_ops_lib, "update_quantized_cache", "Meta")
+def update_quantized_cache_meta(
+    value,
+    cache,
+    start_pos,
+):
+    _validate_update_cache_params(
+        value,
+        cache,
+        start_pos,
+    )
+
+    # update_quantized_cache doesn't really return anything, but there is no
+    # better workaround. Should we just return cache instead? That might
+    # result in an extra memory allocation.
+    return torch.empty((1,), dtype=value.dtype, device="meta")
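This meta kernel is what lets the op trace through export and compile flows, where tensors carry only shapes and dtypes. A small sketch of it in action under FakeTensorMode, assuming the schema and this meta registration have already been loaded (shapes are illustrative):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    value = torch.empty(1, 4, 8, 64, dtype=torch.int8)
    cache = torch.empty(1, 256, 8, 64, dtype=torch.int8)
    # Dispatches to update_quantized_cache_meta: validates shapes/dtypes
    # and returns the 1-element placeholder without touching real data.
    out = torch.ops.llama.update_quantized_cache(value, cache, 2)
    print(out.shape)  # torch.Size([1])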

extension/llm/custom_ops/targets.bzl

Lines changed: 21 additions & 0 deletions
@@ -62,10 +62,31 @@ def define_common_targets():
         ],
         deps = [
             ":custom_ops" + mkl_dep,
+            ":update_quantized_cache",
             "//executorch/extension/aten_util:aten_bridge",
         ],
     )
 
+    runtime.cxx_library(
+        name = "update_quantized_cache",
+        srcs = ["op_update_quantized_cache.cpp"],
+        exported_headers = ["op_update_quantized_cache.h"],
+        exported_deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/kernels/portable/cpu:scalar_utils",
+            "//executorch/extension/kernel_util:kernel_util",
+        ],
+        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
+        visibility = [
+            "//executorch/...",
+            "//executorch/extension/llm/custom_ops/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        # @lint-ignore BUCKLINT link_whole
+        link_whole = True,
+        force_static = True,
+    )
+
     runtime.python_library(
         name = "custom_ops_aot_py",
         srcs = [
