Commit 3af390b

[Executorch][llama] Rename update_quantized_cache to update_cache

Differential Revision: D66041160
Pull Request resolved: #6914
Parent: 965fa27

8 files changed: +35 additions, −50 deletions
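The rename changes only the operator's public name; the schema and semantics are untouched. Below is a minimal before/after sketch of a call site, assuming the llama custom-ops library has already been loaded (e.g. via torch.ops.load_library); the tensor shapes and dtype are placeholders for illustration, not values from this commit.

import torch

# Placeholder quantized cache and incoming values; real shapes and dtypes
# come from the calling module, not from this commit.
value = torch.zeros(1, 1, 8, 64, dtype=torch.int8)
cache = torch.zeros(1, 128, 8, 64, dtype=torch.int8)
start_pos = 0

# Before this commit:
#   torch.ops.llama.update_quantized_cache(value, cache, start_pos)
# After this commit:
torch.ops.llama.update_cache(value, cache, start_pos)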

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 8 additions & 16 deletions

@@ -151,22 +151,14 @@ def update(self, input_pos, k_val, v_val):
         # instead of quantizing on their own.
         # But until this opting for code simplicity
         start_pos = input_pos[0].item()
-        _ = torch.ops.llama.update_quantized_cache(
-            quantized_k_val, self.k_cache, start_pos
-        )
-        _ = torch.ops.llama.update_quantized_cache(
-            k_scales, self.k_cache_scales, start_pos
-        )
-        _ = torch.ops.llama.update_quantized_cache(
+        _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
+        _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
             k_zero_points, self.k_cache_zero_points, start_pos
         )
-        _ = torch.ops.llama.update_quantized_cache(
-            quantized_v_val, self.v_cache, start_pos
-        )
-        _ = torch.ops.llama.update_quantized_cache(
-            v_scales, self.v_cache_scales, start_pos
-        )
-        _ = torch.ops.llama.update_quantized_cache(
+        _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
+        _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
             v_zero_points, self.v_cache_zero_points, start_pos
         )

@@ -205,8 +197,8 @@ def update(self, input_pos, k_val, v_val):
             v_out[:, :, input_pos] = v_val
         else:
             start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_quantized_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_quantized_cache(v_val, v_out, start_pos)
+            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)

         return k_out, v_out
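As the diff shows, the quantized path keeps three tensors per K/V side (quantized values, per-token scales, and zero points) and updates each cache with the same op. For reference, reading such a cache back is plain affine dequantization; the helper below is an illustrative sketch that assumes scale and zero-point shapes broadcast against the quantized tensor, not code from this commit.

import torch

def dequantize_cache(q_cache, scales, zero_points):
    # Affine dequantization: real = (quantized - zero_point) * scale.
    # Assumes scales/zero_points broadcast against q_cache.
    return (q_cache.to(torch.float32) - zero_points.to(torch.float32)) * scales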

extension/llm/custom_ops/TARGETS

Lines changed: 2 additions & 2 deletions

@@ -23,9 +23,9 @@ runtime.python_test(
 )

 runtime.python_test(
-    name = "test_update_quantized_cache",
+    name = "test_update_cache",
     srcs = [
-        "test_update_quantized_cache.py",
+        "test_update_cache.py",
     ],
     preload_deps = [
         ":custom_ops_aot_lib",

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 10 additions & 13 deletions

@@ -9,7 +9,7 @@
 #include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
 #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
 #include <executorch/extension/llm/custom_ops/op_sdpa.h>
-#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>
+#include <executorch/extension/llm/custom_ops/op_update_cache.h>

 #include <torch/library.h>

@@ -127,22 +127,22 @@ at::Tensor custom_sdpa_aten(
   return output;
 }

-Tensor& update_quantized_cache_out_no_context(
+Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
     Tensor& output) {
   exec_aten::RuntimeContext context{};
-  return torch::executor::native::update_quantized_cache_out(
+  return torch::executor::native::update_cache_out(
       context, value, cache, start_pos, output);
 }

-at::Tensor update_quantized_cache_aten(
+at::Tensor update_cache_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
+  WRAP_TO_ATEN(update_cache_out_no_context, 3)
   (value, cache, start_pos, output);
   return output;
 }

@@ -169,10 +169,10 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
       "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
-      "update_quantized_cache(Tensor value, Tensor(a!) cache, "
+      "update_cache(Tensor value, Tensor(a!) cache, "
       "SymInt start_pos) -> Tensor");
   m.def(
-      "update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
+      "update_cache.out(Tensor value, Tensor(a!) cache, "
       "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
 }

@@ -188,11 +188,8 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "custom_sdpa.out",
       WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
+  m.impl("update_cache", torch::executor::native::update_cache_aten);
   m.impl(
-      "update_quantized_cache",
-      torch::executor::native::update_quantized_cache_aten);
-  m.impl(
-      "update_quantized_cache.out",
-      WRAP_TO_ATEN(
-          torch::executor::native::update_quantized_cache_out_no_context, 3));
+      "update_cache.out",
+      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
 }
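The two schema strings defined above can be reproduced with torch.library when experimenting outside the C++ build. The "demo" namespace below is hypothetical; this commit registers the schemas under "llama" from C++.

from torch.library import Library

lib = Library("demo", "DEF")  # hypothetical namespace, for illustration only
lib.define("update_cache(Tensor value, Tensor(a!) cache, SymInt start_pos) -> Tensor")
lib.define(
    "update_cache.out(Tensor value, Tensor(a!) cache, "
    "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)"
)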

extension/llm/custom_ops/op_update_quantized_cache.cpp renamed to extension/llm/custom_ops/op_update_cache.cpp

Lines changed: 4 additions & 4 deletions

@@ -6,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */

-#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>
+#include <executorch/extension/llm/custom_ops/op_update_cache.h>

 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 // @lint-ignore CLANGTIDY facebook-unused-include-check

@@ -60,7 +60,7 @@ bool validate_cache_params(
 }
 } // anonymous namespace

-Tensor& update_quantized_cache_out(
+Tensor& update_cache_out(
     RuntimeContext& ctx,
     const Tensor& value,
     Tensor& cache,

@@ -139,5 +139,5 @@ Tensor& update_quantized_cache_out(
 // In later diffs will rename this to update_cache.
 EXECUTORCH_LIBRARY(
     llama,
-    "update_quantized_cache.out",
-    torch::executor::native::update_quantized_cache_out);
+    "update_cache.out",
+    torch::executor::native::update_cache_out);

extension/llm/custom_ops/op_update_quantized_cache.h renamed to extension/llm/custom_ops/op_update_cache.h

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ namespace executor {

 namespace native {

-Tensor& update_quantized_cache_out(
+Tensor& update_cache_out(
     RuntimeContext& ctx,
     const Tensor& value,
     Tensor& cache,

extension/llm/custom_ops/sdpa_with_kv_cache.py

Lines changed: 2 additions & 2 deletions

@@ -203,8 +203,8 @@ def _validate_update_cache_params(
     ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"


-@impl(custom_ops_lib, "update_quantized_cache", "Meta")
-def update_quantized_cache_meta(
+@impl(custom_ops_lib, "update_cache", "Meta")
+def update_cache_meta(
     value,
     cache,
     start_pos,
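The Meta kernel exists so the op can be traced and exported: it only has to return metadata-correct outputs. Below is a standalone sketch of the same pattern, assuming (consistently with the AOT wrapper's at::empty({1}) above) that the op returns a dummy one-element tensor; the "demo_meta" namespace is hypothetical.

import torch
from torch.library import Library, impl

demo_lib = Library("demo_meta", "DEF")  # hypothetical namespace
demo_lib.define("update_cache(Tensor value, Tensor(a!) cache, SymInt start_pos) -> Tensor")

@impl(demo_lib, "update_cache", "Meta")
def update_cache_meta(value, cache, start_pos):
    # The eager kernel mutates `cache` in place; tracing only needs a
    # dummy return whose metadata matches the schema.
    return value.new_empty((1,))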

extension/llm/custom_ops/targets.bzl

Lines changed: 2 additions & 2 deletions

@@ -22,13 +22,13 @@ def define_common_targets():
            "op_fallback.cpp",
            "op_fast_hadamard_transform.cpp",
            "op_sdpa.cpp",
-           "op_update_quantized_cache.cpp",
+           "op_update_cache.cpp",
        ],
        exported_headers = [
            "op_fallback.h",
            "op_fast_hadamard_transform.h",
            "op_sdpa.h",
-           "op_update_quantized_cache.h",
+           "op_update_cache.h",
        ],
        preprocessor_flags = get_vec_preprocessor_flags(),
        exported_deps = [

extension/llm/custom_ops/test_update_quantized_cache.py renamed to extension/llm/custom_ops/test_update_cache.py

Lines changed: 6 additions & 10 deletions

@@ -67,17 +67,13 @@ def _update_and_validate(
         self._update_k(start_pos, k, k_scales, k_zero_points)
         self._update_v(start_pos, v, v_scales, v_zero_points)

-        torch.ops.llama.update_quantized_cache(k, k_cache, start_pos)
-        torch.ops.llama.update_quantized_cache(k_scales, k_scales_cache, start_pos)
-        torch.ops.llama.update_quantized_cache(
-            k_zero_points, k_zero_points_cache, start_pos
-        )
+        torch.ops.llama.update_cache(k, k_cache, start_pos)
+        torch.ops.llama.update_cache(k_scales, k_scales_cache, start_pos)
+        torch.ops.llama.update_cache(k_zero_points, k_zero_points_cache, start_pos)

-        torch.ops.llama.update_quantized_cache(v, v_cache, start_pos)
-        torch.ops.llama.update_quantized_cache(v_scales, v_scales_cache, start_pos)
-        torch.ops.llama.update_quantized_cache(
-            v_zero_points, v_zero_points_cache, start_pos
-        )
+        torch.ops.llama.update_cache(v, v_cache, start_pos)
+        torch.ops.llama.update_cache(v_scales, v_scales_cache, start_pos)
+        torch.ops.llama.update_cache(v_zero_points, v_zero_points_cache, start_pos)

         self.assertTrue(torch.allclose(k_cache, self.quantized_k_cache))
         self.assertTrue(torch.allclose(v_cache, self.quantized_v_cache))
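The test validates the op against reference caches maintained by _update_k and _update_v. In plain PyTorch, the expected in-place behavior amounts to a slice write along the sequence dimension; dim 1 below is an assumption inferred from the cache.size(1) bound checked in sdpa_with_kv_cache.py, and the helper is illustrative rather than the test's actual reference.

import torch

def reference_update(value: torch.Tensor, cache: torch.Tensor, start_pos: int) -> None:
    # Copy `value` into `cache` starting at `start_pos` along dim 1
    # (assumed here to be the sequence dimension).
    seq_len = value.size(1)
    cache[:, start_pos : start_pos + seq_len] = value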
