
Commit d364c67

Add update_quantized_cache op
Why?
- Functionalization introduces a ton of copies.
- Supporting mutable buffers without such custom in-place ops will result in giant copies at the end.
- Making in-place ops work will likely take longer, and there is no clear safe path.

Differential Revision: [D62301838](https://our.internmc.facebook.com/intern/diff/D62301838/)

ghstack-source-id: 243859228
Pull Request resolved: #5527
1 parent b4f5ec5 commit d364c67
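For orientation, here is a minimal usage sketch of the new op from Python, assuming the AOT library has already been loaded; the shapes, dtype, and values are illustrative, not taken from this commit:

import torch

# Hypothetical setup: load the built custom_ops_aot_lib first, e.g.
# torch.ops.load_library("<path-to>/libcustom_ops_aot_lib.so").

# Cache layout is [batch=1, max_seq_len, n_heads, head_dim]; the op
# indexes along the sequence dimension (dim 1).
cache = torch.zeros(1, 16, 8, 64, dtype=torch.int8)
value = torch.randint(-128, 127, (1, 2, 8, 64), dtype=torch.int8)

# Writes `value` into `cache` at positions [4, 6) along dim 1, in place.
# The returned tensor is a placeholder and can be ignored.
_ = torch.ops.llama.update_quantized_cache(value, cache, 4)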

7 files changed: +415 −2 lines changed


extension/llm/custom_ops/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -22,6 +22,19 @@ runtime.python_test(
     ],
 )
 
+runtime.python_test(
+    name = "test_update_quantized_cache",
+    srcs = [
+        "test_update_quantized_cache.py",
+    ],
+    preload_deps = [
+        ":custom_ops_aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "test_preprocess_custom_ops",
     srcs = [
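Assuming a standard Buck setup, the new test target would be run with something like the following (the exact cell prefix depends on your checkout):

buck2 test //executorch/extension/llm/custom_ops:test_update_quantized_cache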

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 39 additions & 2 deletions
@@ -9,14 +9,14 @@
 #include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
 #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
 #include <executorch/extension/llm/custom_ops/op_sdpa.h>
+#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>
 
 #include <torch/library.h>
 
 namespace torch {
 namespace executor {
 
 namespace native {
-namespace {
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,

@@ -81,7 +81,27 @@ at::Tensor sdpa_with_kv_cache_aten(
       output);
   return output;
 }
-} // namespace
+
+Tensor& update_quantized_cache_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::native::update_quantized_cache_out(
+      context, value, cache, start_pos, output);
+}
+
+at::Tensor update_quantized_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos) {
+  auto output = at::empty({1});
+  WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
+  (value, cache, start_pos, output);
+  return output;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch

@@ -95,6 +115,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
       "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
       "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+  m.def(
+      "update_quantized_cache(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos) -> Tensor");
+  m.def(
+      "update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
 }
 
 TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {

@@ -105,3 +131,14 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
       WRAP_TO_ATEN(
           torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
 }
+
+// TODO: Rename this file to op_custom_ops_aot.cpp
+TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl(
+      "update_quantized_cache",
+      torch::executor::native::update_quantized_cache_aten);
+  m.impl(
+      "update_quantized_cache.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::update_quantized_cache_out_no_context, 3));
+}
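Once this fragment is registered and the shared library is loaded, both overloads should be visible under torch.ops.llama. A quick smoke check (the library path is hypothetical; use your build's output):

import torch

torch.ops.load_library("libcustom_ops_aot_lib.so")  # hypothetical path

print(torch.ops.llama.update_quantized_cache)      # functional variant
print(torch.ops.llama.update_quantized_cache.out)  # out variant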
extension/llm/custom_ops/op_update_quantized_cache.cpp

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>
+
+#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+// @lint-ignore CLANGTIDY facebook-unused-include-check
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+
+#include <array>
+// patternlint-disable-next-line executorch-cpp-nostdinc
+#include <vector>
+
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+namespace {
+bool validate_cache_params(
+    const Tensor& quantized_value,
+    const Tensor& quantized_cache,
+    int64_t start_pos,
+    int64_t seq_length) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      start_pos < quantized_cache.size(1),
+      "start_pos must be less than cache size at dim 1");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      (start_pos + seq_length) <= quantized_cache.size(1),
+      "start_pos + seq_length must not exceed the max seq length supported by the cache. "
+      "start_pos: %" PRId64 ", seq_length: %" PRId64 ", cache size: %zd",
+      start_pos,
+      seq_length,
+      quantized_cache.size(1));
+
+  // Make sure they are in contiguous dim order
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      is_contiguous_dim_order(
+          quantized_cache.dim_order().data(), quantized_cache.dim()),
+      "quantized cache must be in contiguous dim order");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      is_contiguous_dim_order(
+          quantized_value.dim_order().data(), quantized_value.dim()),
+      "quantized value must be in contiguous dim order");
+
+  return true;
+}
+} // anonymous namespace
+
+Tensor& update_quantized_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  (void)ctx;
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len),
+      InvalidArgument,
+      output);
+
+  ET_CHECK_MSG(value.dim() == 4, "value must be a 4D tensor");
+  ET_CHECK_MSG(value.size(0) == 1, "value must have batch size of 1");
+  ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
+
+  const void* value_data = value.const_data_ptr();
+  void* cache_data = cache.mutable_data_ptr();
+
+  ET_CHECK_MSG(value_data != nullptr, "projected_value data is null");
+  ET_CHECK_MSG(cache_data, "cache data is null");
+
+  auto strides = cache.strides();
+  exec_aten::StridesType seq_dim_stride = strides[1];
+  exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
+  exec_aten::SizesType pos_offset_bytes = pos_offset * value.element_size();
+  exec_aten::SizesType num_bytes = value.numel() * value.element_size();
+  // NOLINTNEXTLINE
+  std::memcpy((uint8_t*)cache_data + pos_offset_bytes, value_data, num_bytes);
+
+  // No one uses the output; it is just a placeholder.
+  return output;
+}
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+// Really this is just an in-place tensor update op that makes assumptions
+// about the rank of the tensor and its dim order (memory layout), and
+// further assumes that the indexing is along the sequence dimension
+// (dim 1) of the tensor. In later diffs this will be renamed to
+// update_cache.
+EXECUTORCH_LIBRARY(
+    llama,
+    "update_quantized_cache.out",
+    torch::executor::native::update_quantized_cache_out);
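As the closing comment notes, the kernel is just an in-place copy along dim 1. A rough PyTorch-level reference for its semantics, under the kernel's stated assumptions (4D tensors, batch size 1, contiguous dim order), might look like:

import torch

def update_cache_reference(value: torch.Tensor, cache: torch.Tensor, start_pos: int) -> torch.Tensor:
    # Slice assignment along the sequence dimension. Because both tensors are
    # contiguous and batch size is 1, the destination region is one contiguous
    # block, which is why the kernel above can implement it as a single memcpy
    # at byte offset start_pos * cache.stride(1) * cache.element_size().
    seq_len = value.size(1)
    cache[:, start_pos : start_pos + seq_len] = value
    return cache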
extension/llm/custom_ops/op_update_quantized_cache.h

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+Tensor& update_quantized_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output);
+} // namespace native
+} // namespace executor
+} // namespace torch

extension/llm/custom_ops/sdpa_with_kv_cache.py

Lines changed: 52 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 from torch.library import impl
 
+# TODO: rename this file to custom_ops_meta_registration.py
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None

@@ -138,3 +139,54 @@ def fast_hadamard_transform_meta(mat):
     # assert(mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!")
     # assert(mat.is_contiguous(), "input matrix must be contiguous currently!")
     return torch.empty_like(mat)
+
+
+def _validate_update_cache_params(
+    value,
+    cache,
+    start_pos,
+):
+    seq_len = value.size(1)
+    assert (
+        value.dim() == 4
+    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."
+
+    assert (
+        value.dtype == cache.dtype
+    ), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}"
+
+    for i in [0, 2, 3]:
+        assert value.size(i) == cache.size(
+            i
+        ), f"Expected value and cache to have the same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
+
+    torch._check_is_size(start_pos)
+    # Setting an arbitrary limit of 256 for now, since there is no way
+    # to plumb this information from the model config.
+    torch._check(start_pos < cache.size(1))
+    assert start_pos < cache.size(
+        1
+    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+
+    torch._check((start_pos + seq_len) < cache.size(1))
+    assert (start_pos + seq_len) < cache.size(
+        1
+    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+
+
+@impl(custom_ops_lib, "update_quantized_cache", "Meta")
+def update_quantized_cache_meta(
+    value,
+    cache,
+    start_pos,
+):
+    _validate_update_cache_params(
+        value,
+        cache,
+        start_pos,
+    )
+
+    # update_quantized_cache doesn't really return anything, but I don't know
+    # a better workaround. Should we just return cache instead? I am afraid
+    # that would result in an extra memory allocation.
+    return torch.empty((1,), dtype=value.dtype, device="meta")
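The meta kernel above is what lets the op be traced or exported on fake tensors without running the real implementation. A minimal sketch of that flow (the module and shapes are illustrative, not from this commit):

import torch

class CacheUpdate(torch.nn.Module):
    def forward(self, value, cache, start_pos):
        return torch.ops.llama.update_quantized_cache(value, cache, start_pos)

value = torch.randint(-128, 127, (1, 1, 8, 64), dtype=torch.int8)
cache = torch.zeros(1, 16, 8, 64, dtype=torch.int8)

# Shape propagation during export goes through update_quantized_cache_meta.
ep = torch.export.export(CacheUpdate(), (value, cache, 0))
print(ep.graph)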

extension/llm/custom_ops/targets.bzl

Lines changed: 21 additions & 0 deletions
@@ -62,10 +62,31 @@ def define_common_targets():
             ],
             deps = [
                 ":custom_ops" + mkl_dep,
+                ":update_quantized_cache",
                 "//executorch/extension/aten_util:aten_bridge",
             ],
         )
 
+        runtime.cxx_library(
+            name = "update_quantized_cache",
+            srcs = ["op_update_quantized_cache.cpp"],
+            exported_headers = ["op_update_quantized_cache.h"],
+            exported_deps = [
+                "//executorch/runtime/kernel:kernel_includes",
+                "//executorch/kernels/portable/cpu:scalar_utils",
+                "//executorch/extension/kernel_util:kernel_util",
+            ],
+            compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
+            visibility = [
+                "//executorch/...",
+                "//executorch/extension/llm/custom_ops/...",
+                "@EXECUTORCH_CLIENTS",
+            ],
+            # @lint-ignore BUCKLINT link_whole
+            link_whole = True,
+            force_static = True,
+        )
+
         runtime.python_library(
             name = "custom_ops_aot_py",
             srcs = [
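A note on the flags above: link_whole = True (together with force_static = True) keeps the linker from dropping the object file that contains the EXECUTORCH_LIBRARY registration, which no symbol references directly; without it, the static registration could be stripped and the op would never be registered at runtime.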
