
Commit 10f89f8

Merge branch 'main' into cherry-pick-13261-by-pytorch_bot_bot_
2 parents 32ef584 + 3a02146 commit 10f89f8

File tree

11 files changed: +595 −79 lines

.github/workflows/build-presets.yml

Lines changed: 0 additions & 2 deletions
@@ -6,8 +6,6 @@ on:
     branches:
       - main
       - release/*
-    paths:
-      - .github/workflows/build-presets.yml
   workflow_dispatch:
 
 concurrency:

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 109 additions & 53 deletions
@@ -219,37 +219,42 @@ def post_process():
 
 
 def smart_mask_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    # ar_len is unused in smart mask
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         for i, k_cache in enumerate(k_caches):
-            k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0]
+            k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates]
 
         for i, v_cache in enumerate(v_caches):
-            v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :]
-        atten_mask[:, :, pos - ar_len] = 0
+            v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :]
+        atten_mask[:, :, pos : pos + n_updates] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
 
 def shift_pointer_updater(
-    ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+    ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
-    # Update the KV cache input for the next inference when the position exceeds the autoregressive length.
-    if pos >= ar_len:
+    max_cache_len = k_caches[0].size(-1)
+    if pos + n_updates <= max_cache_len:
         k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1)
+            torch.cat(
+                [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], dim=-1
+            )
             for i, k_cache in enumerate(k_caches)
         ]
         v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1)
+            torch.cat(
+                [v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], dim=1
+            )
             for i, v_cache in enumerate(v_caches)
         ]
-        atten_mask[:, :, -pos - 1] = 0
+        atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0
+        pos += n_updates
 
-    pos += 1
     return (atten_mask, pos, k_caches, v_caches)
 
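Both updaters now take an n_updates count so that a whole chunk of freshly computed K/V entries is committed per call rather than a single token, and pos only advances when the chunk still fits in the cache. Below is a minimal standalone sketch of the smart-mask variant with toy shapes (hypothetical sizes, not the repo's module; the layouts [batch, head_dim, cache_len] for K and [batch, cache_len, head_dim] for V follow the indexing used above):

# Standalone sketch (toy shapes, hypothetical values) of the smart-mask chunk update.
import torch

batch, head_dim, cache_len, n_updates, pos = 1, 4, 8, 3, 2
k_cache = torch.zeros(batch, head_dim, cache_len)            # [B, D, L], as indexed above
v_cache = torch.zeros(batch, cache_len, head_dim)            # [B, L, D]
atten_mask = torch.full((batch, n_updates, cache_len), -255.0)
new_k = torch.randn(batch, head_dim, n_updates)
new_v = torch.randn(batch, n_updates, head_dim)

if pos + n_updates <= cache_len:                             # same guard as smart_mask_updater
    k_cache[:, :, pos : pos + n_updates] = new_k             # write the whole chunk in place
    v_cache[:, pos : pos + n_updates, :] = new_v
    atten_mask[:, :, pos : pos + n_updates] = 0               # unmask the newly filled slots
    pos += n_updates

print(pos)  # 5: the write position advances by the chunk size, not by 1

shift_pointer_updater reaches the same state by rolling the caches with torch.cat instead of writing in place.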

@@ -269,70 +274,121 @@ def kv_inference(
     # TODO: change criteria & support batch inputs if necessary
     all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0)
 
-    token_list, result_logits = [], []
+    prompt_token_list, total_token_list, result_logits = [], [], []
 
     if isinstance(prompt, str):
         # Llama2 tokenizer has no special tokens
         if isinstance(tokenizer, (SentencePieceTokenizer, HuggingFaceTokenizer)):
-            token_list = tokenizer.encode(prompt, bos=True, eos=False)
+            prompt_token_list = tokenizer.encode(prompt, bos=True, eos=False)
         elif isinstance(tokenizer, TiktokenTokenizer):
-            token_list = tokenizer.encode(
+            prompt_token_list = tokenizer.encode(
                 prompt, bos=True, eos=False, allowed_special="all"
             )
         else:
             raise RuntimeError("Unknown tokenizer")
     else:
         # pyre-ignore
-        token_list = prompt.flatten().tolist()
-    pos = len(token_list) if len(token_list) < ar_len else ar_len
+        prompt_token_list = prompt.flatten().tolist()
+    total_token_list = prompt_token_list
     dtype = torch.int64 if use_i64_token else torch.int32
 
     with torch.no_grad():
-        while token_list[-1] != tokenizer.eos_id and pos < max_seq_len:
-            tmp_token_list = torch.tensor(
-                token_list[pos - ar_len : pos], dtype=dtype
-            ).reshape(1, -1)
-            tmp_pos = all_pos[:, pos - ar_len : pos]
-            tmp_atten_mask = atten_mask
-            if pos < ar_len:
-                tmp_token_list = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=dtype),
-                        torch.tensor(token_list, dtype=dtype).reshape(1, -1),
-                    ],
-                    dim=1,
-                )
-                tmp_pos = torch.cat(
-                    [
-                        torch.zeros((1, ar_len - pos), dtype=torch.int32),
-                        all_pos[:, :pos],
-                    ],
-                    dim=1,
-                )
-                tmp_atten_mask = torch.cat(
-                    [
-                        torch.ones(1, ar_len, max_seq_len - pos) * -255.0,
-                        atten_mask[:, :, -pos:],
-                    ],
-                    dim=-1,
-                )
+        # Phase 1: Prefill the prompt in ar_len chunks.
+        num_prompt_tokens = len(prompt_token_list)
+        pos = 0  # Tracks how many prompt tokens have been processed.
+        while pos < num_prompt_tokens:
+            chunk_start_idx = pos
+            # Take a chunk of prompt tokens, up to ar_len length.
+            chunk_end_idx = min(num_prompt_tokens, pos + ar_len)
+            actual_chunk_tokens = prompt_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
 
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[
+                0,
+                pos : pos + num_tokens_in_chunk,
+            ]
+
+            # Run inference.
             logits, new_k_caches, new_v_caches = module(
                 tmp_token_list,
-                tmp_atten_mask,
+                atten_mask,
                 tmp_pos,
                 *k_caches,
                 *v_caches,
             )
             if collect_logits:
-                result_logits.append(logits)
+                result_logits.append(logits[:, :num_tokens_in_chunk])
+
+            # Update the pos, KV cache and attention mask.
             atten_mask, pos, k_caches, v_caches = kv_updater(
-                ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+                ar_len,
+                num_tokens_in_chunk,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+        # Append the last run logits to the total_token_list.
+        total_token_list.append(
+            torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+        )
+
+        # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
+        # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
+        max_cache_len = max_seq_len - ar_len
+        num_tokens = len(total_token_list)
+        while total_token_list[-1] != tokenizer.eos_id and num_tokens < max_seq_len:
+            chunk_start_idx = min(pos, max_cache_len)
+            # Take a chunk of generated tokens, up to ar_len length.
+            chunk_end_idx = num_tokens
+            actual_chunk_tokens = total_token_list[chunk_start_idx:chunk_end_idx]
+            num_tokens_in_chunk = len(actual_chunk_tokens)
+
+            # Prepare tmp_token_list (padded with zeros).
+            tmp_token_list = torch.zeros((1, ar_len), dtype=dtype)
+            tmp_token_list[0, :num_tokens_in_chunk] = torch.tensor(
+                actual_chunk_tokens, dtype=dtype
+            )
+
+            # Prepare tmp_pos (padded with zeros).
+            tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
+            tmp_pos[0, :num_tokens_in_chunk] = all_pos[0, chunk_start_idx:chunk_end_idx]
+
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
             )
-            if pos > len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+            if collect_logits:
+                result_logits.append(logits[:, :num_tokens_in_chunk])
 
-        logging.info(f"kv inference result:\n{tokenizer.decode(token_list)}")
+            atten_mask, pos, k_caches, v_caches = kv_updater(
+                ar_len,
+                1,
+                atten_mask,
+                pos,
+                k_caches,
+                v_caches,
+                new_k_caches,
+                new_v_caches,
+            )
+            total_token_list.append(
+                torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
+            )
+            num_tokens = len(total_token_list)
+        logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
         return result_logits
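The rewritten kv_inference runs in two phases: prefill consumes the prompt in ar_len-sized windows, and decode then generates one token per step (kv_updater is called with 1). A standalone sketch of just the chunking and zero padding used for tmp_token_list and tmp_pos (toy token values, no tokenizer or model call):

# Standalone sketch of the Phase-1 chunking: split a prompt into ar_len windows,
# zero-padding the last partial window for both tokens and positions.
import torch

ar_len = 4
prompt_token_list = [11, 12, 13, 14, 15, 16, 17]   # 7 tokens -> chunks of 4 and 3
pos = 0
while pos < len(prompt_token_list):
    chunk = prompt_token_list[pos : pos + ar_len]
    n = len(chunk)

    tmp_token_list = torch.zeros((1, ar_len), dtype=torch.int32)
    tmp_token_list[0, :n] = torch.tensor(chunk, dtype=torch.int32)

    tmp_pos = torch.zeros((1, ar_len), dtype=torch.int32)
    tmp_pos[0, :n] = torch.arange(pos, pos + n, dtype=torch.int32)

    print(tmp_token_list.tolist(), tmp_pos.tolist())
    pos += n   # in the real loop, kv_updater advances pos by the chunk size
# [[11, 12, 13, 14]] [[0, 1, 2, 3]]
# [[15, 16, 17, 0]]  [[4, 5, 6, 0]]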
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10+
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
namespace native {
16+
17+
using Tensor = executorch::aten::Tensor;
18+
19+
template <typename T>
20+
using OptionalArrayRef = executorch::aten::OptionalArrayRef<T>;
21+
22+
/**
23+
* _clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]?
24+
* dim_order=None, Tensor(a!) out) -> Tensor(a!)
25+
*
26+
* Clones via element-wise copy while preserving dim_order.
27+
*/
28+
Tensor& _clone_dim_order_out(
29+
KernelRuntimeContext& ctx,
30+
const Tensor& self,
31+
bool non_blocking,
32+
OptionalArrayRef<int64_t> dim_order,
33+
Tensor& out) {
34+
(void)ctx;
35+
36+
// Ensure input and output dtype match.
37+
ET_KERNEL_CHECK(
38+
ctx, self.scalar_type() == out.scalar_type(), InvalidArgument, out);
39+
40+
// Ensure output has the same layout as input or matches dim_order.
41+
ET_KERNEL_CHECK(
42+
ctx,
43+
check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
44+
InvalidArgument,
45+
out);
46+
47+
// Ensure input and output shapes match, resizing if necessary.
48+
ET_KERNEL_CHECK(
49+
ctx,
50+
resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
51+
InvalidArgument,
52+
out);
53+
54+
if (self.numel() == 0) {
55+
return out;
56+
}
57+
58+
// Select the correct input dtype and copy the tensors.
59+
ET_SWITCH_REALHBBF16_TYPES(
60+
self.scalar_type(),
61+
ctx,
62+
"dim_order_ops::_clone_dim_order.out",
63+
CTYPE,
64+
[&] { _to_dim_order_copy_impl<CTYPE, CTYPE>(self, out); });
65+
66+
return out;
67+
}
68+
69+
Tensor& _clone_dim_order_out(
70+
const Tensor& self,
71+
bool non_blocking,
72+
OptionalArrayRef<int64_t> dim_order,
73+
Tensor& out) {
74+
executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext context{};
75+
return _clone_dim_order_out(context, self, non_blocking, dim_order, out);
76+
}
77+
78+
} // namespace native
79+
} // namespace executor
80+
} // namespace torch
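For intuition, the following is the eager-PyTorch analogue of what this new _clone_dim_order kernel provides: a clone whose output keeps the input's memory layout (dim_order) and copies values element-wise. The sketch only illustrates the semantics with torch.clone; it does not exercise the ExecuTorch kernel itself:

# Eager-mode analogue (not this kernel's code path) of "clone preserving dim_order".
import torch

x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
y = x.clone(memory_format=torch.preserve_format)

print(x.stride() == y.stride())                            # True: layout preserved
print(y.is_contiguous(memory_format=torch.channels_last))  # True
print(torch.equal(x, y))                                   # True: element-wise copy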

kernels/portable/cpu/op__to_dim_order_copy.cpp

Lines changed: 0 additions & 23 deletions
@@ -29,29 +29,6 @@ using OptionalArrayRef = executorch::aten::OptionalArrayRef<T>;
 template <typename T>
 using Optional = std::optional<T>;
 
-namespace {
-
-template <typename SELF_CTYPE, typename OUT_CTYPE>
-void _to_dim_order_copy_impl(const Tensor& self, Tensor& out) {
-  auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
-  auto out_data = out.mutable_data_ptr<OUT_CTYPE>();
-
-  // Here we make a slightly off-label use of
-  // BroadcastIndexesRange. It always assumes it doesn't have to care
-  // about different dim_order between input and output, but we can
-  // just force it to respect strides (and thus dim_order) for its
-  // inputs using support_noncontiguous_input_tensors=true, and then pretend
-  // the output is just another input.
-  for (const auto [unused_index, self_data_index, out_data_index] :
-       BroadcastIndexesRange<2, /*support_noncontiguous_input_tensors=*/true>(
-           /*dummy output*/ self, self, out)) {
-    (void)unused_index;
-    out_data[out_data_index] =
-        static_cast<OUT_CTYPE>(self_data[self_data_index]);
-  }
-}
-} // namespace
-
 // _to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]?
 // dim_order=None, Tensor(a!) out) -> Tensor(a!)
 Tensor& _to_dim_order_copy_out(

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 24 additions & 0 deletions
@@ -9,6 +9,7 @@
 #pragma once
 #include <c10/util/irange.h>
 
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -77,6 +78,29 @@ void as_strided_copy(
   }
 }
 
+/**
+ * Copies and casts a tensor while preserving input dim_order.
+ */
+template <typename SELF_CTYPE, typename OUT_CTYPE>
+void _to_dim_order_copy_impl(const Tensor& self, Tensor& out) {
+  auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
+  auto out_data = out.mutable_data_ptr<OUT_CTYPE>();
+
+  // Here we make a slightly off-label use of
+  // BroadcastIndexesRange. It always assumes it doesn't have to care
+  // about different dim_order between input and output, but we can
+  // just force it to respect strides (and thus dim_order) for its
+  // inputs using support_noncontiguous_input_tensors=true, and then pretend
+  // the output is just another input.
+  for (const auto [unused_index, self_data_index, out_data_index] :
+       BroadcastIndexesRange<2, /*support_noncontiguous_input_tensors=*/true>(
+           /*dummy output*/ self, self, out)) {
+    (void)unused_index;
+    out_data[out_data_index] =
+        static_cast<OUT_CTYPE>(self_data[self_data_index]);
+  }
+}
+
 bool check_cat_args(
     executorch::aten::ArrayRef<Tensor> tensors,
     int64_t dim,
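The relocated _to_dim_order_copy_impl pairs up elements of two tensors that may be stored with different dim_orders by mapping each logical index through each tensor's own strides. A plain-Python sketch of that stride arithmetic on a toy 2x3 tensor (row-major source, column-major destination; this is only the underlying idea, not the BroadcastIndexesRange machinery itself):

# Standalone sketch: stride-respecting element-wise copy between two layouts.
import itertools

shape = (2, 3)
src_strides = (3, 1)             # row-major buffer
dst_strides = (1, 2)             # column-major buffer
src = [10, 11, 12, 20, 21, 22]   # logical [[10, 11, 12], [20, 21, 22]], row-major
dst = [0] * 6

for idx in itertools.product(*(range(s) for s in shape)):
    src_off = sum(i * s for i, s in zip(idx, src_strides))   # flat offset in src layout
    dst_off = sum(i * s for i, s in zip(idx, dst_strides))   # flat offset in dst layout
    dst[dst_off] = src[src_off]

print(dst)  # [10, 20, 11, 21, 12, 22]: same logical tensor, column-major storage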

kernels/portable/cpu/util/targets.bzl

Lines changed: 3 additions & 1 deletion
@@ -147,6 +147,9 @@ def define_common_targets():
             "copy_ops_util.h",
         ],
         compiler_flags = ["-Wno-missing-prototypes"],
+        exported_deps = [
+            ":broadcast_util",
+        ],
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
         ],
@@ -348,7 +351,6 @@ def define_common_targets():
         ],
     )
 
-
     runtime.cxx_library(
         name = "arange_util{}".format(suffix),
        srcs = ["arange_util.cpp"],
