Skip to content

Commit 47b121a

Browse files
committed
Support cache positions
1 parent 9c4357a commit 47b121a

File tree

8 files changed

+208
-15
lines changed

8 files changed

+208
-15
lines changed

examples/models/phi-3-mini/README.md

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,7 @@ python -m examples.models.phi-3-mini.export_phi-3-mini -c "4k" -s 128 -o phi-3-m
2222
3. Build and run the model.
2323
- Build executorch with optimized CPU performance as follows. Build options are available [here](https://github.com/pytorch/executorch/blob/main/CMakeLists.txt#L59).
2424
```
25-
cmake -DPYTHON_EXECUTABLE=python \
26-
-DCMAKE_INSTALL_PREFIX=cmake-out \
27-
-DEXECUTORCH_ENABLE_LOGGING=1 \
28-
-DCMAKE_BUILD_TYPE=Release \
29-
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
30-
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
31-
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
32-
-DEXECUTORCH_BUILD_XNNPACK=ON \
33-
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
34-
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
35-
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
36-
-Bcmake-out .
25+
cmake --preset llm -DCMAKE_INSTALL_PREFIX=cmake-out
3726
3827
cmake --build cmake-out -j16 --target install --config Release
3928
```

examples/models/phi-3-mini/export_phi-3-mini.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,15 @@ def export(args) -> None:
5151
torch.tensor(
5252
[[1048, 263, 931, 746]], dtype=torch.long, requires_grad=False
5353
),
54+
torch.tensor([[0, 1, 2, 3]], dtype=torch.long, requires_grad=False),
5455
)
5556
dynamic_shapes = {
5657
"input_ids": {
5758
1: torch.export.Dim("sequence_length", min=1, max=args.seq_len)
58-
}
59+
},
60+
"cache_positions": {
61+
1: torch.export.Dim("sequence_length", min=1, max=args.seq_len)
62+
},
5963
}
6064

6165
xnnpack_quant_config = get_symmetric_quantization_config(

extension/llm/runner/text_llm_runner.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ static constexpr auto kMaxContextLen = "get_max_context_len";
3232
static constexpr auto kVocabSize = "get_vocab_size";
3333
static constexpr auto kUseKVCache = "use_kv_cache";
3434
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
35+
static constexpr auto kUseCachePositions = "use_cache_positions";
3536

3637
TextLLMRunner::TextLLMRunner(
3738
std::unordered_map<std::string, int64_t> metadata,
@@ -306,6 +307,7 @@ std::unordered_map<std::string, int64_t> get_llm_metadata(
306307
{llm::kMaxContextLen, 128},
307308
{llm::kUseKVCache, true},
308309
{llm::kUseSDPAWithKVCache, false},
310+
{llm::kUseCachePositions, false},
309311
});
310312

311313
// Read metadata from the model
@@ -335,7 +337,24 @@ std::unordered_map<std::string, int64_t> get_llm_metadata(
335337
// Set tokenizer-related metadata
336338
metadata[llm::kBosId] = tokenizer->bos_tok();
337339
metadata[llm::kVocabSize] = tokenizer->vocab_size();
338-
return metadata;
340+
341+
// Override metadata using the module's method_meta
342+
auto method_meta_result = module->method_meta("forward");
343+
if (method_meta_result.error() != Error::Ok) {
344+
ET_LOG(Error, "Failed reading method meta");
345+
return metadata;
346+
}
347+
auto method_meta = method_meta_result.get();
348+
// If only 1 input, we are not using kv cache
349+
metadata[llm::kUseKVCache] = method_meta.num_inputs() > 1;
350+
351+
if (method_meta.num_inputs() == 1) {
352+
return metadata;
353+
}
354+
// Check if we are using cache positions instead of input pos.
355+
auto second_input_info = method_meta.input_tensor_meta(1).get();
356+
// input_pos has shape [1] (rank 1); cache_positions has shape [1, max_seq_len] (rank 2), so the rank of the second input tells them apart.
357+
metadata[llm::kUseCachePositions] = second_input_info.sizes().size() == 2;
339358
}
340359

341360
std::unordered_set<uint64_t> get_eos_ids(

extension/llm/runner/text_prefiller.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ namespace llm {
1919
TextPrefiller::TextPrefiller(
2020
TextDecoderRunner* text_decoder_runner,
2121
bool use_kv_cache,
22+
bool use_cache_positions,
2223
bool enable_parallel_prefill,
2324
int64_t max_seq_len)
2425
: text_decoder_runner_(text_decoder_runner),
2526
use_kv_cache_(use_kv_cache),
27+
use_cache_positions_(use_cache_positions),
2628
enable_parallel_prefill_(enable_parallel_prefill),
2729
max_seq_len_(max_seq_len > 0 ? max_seq_len : 128) {}
2830

extension/llm/runner/text_prefiller.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ class ET_EXPERIMENTAL TextPrefiller {
2121
public:
2222
TextPrefiller(
2323
TextDecoderRunner* text_decoder_runner,
24-
bool use_kv_cache_,
24+
bool use_kv_cache,
25+
bool use_cache_positions,
2526
bool enable_parallel_prefill,
2627
int64_t max_seq_len = 128);
2728

@@ -75,6 +76,7 @@ class ET_EXPERIMENTAL TextPrefiller {
7576
*/
7677
TextDecoderRunner* text_decoder_runner_;
7778
bool use_kv_cache_;
79+
bool use_cache_positions_;
7880
bool enable_parallel_prefill_;
7981
int64_t max_seq_len_;
8082
};

extension/tensor/tensor_ptr_maker.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,5 +186,59 @@ TensorPtr randint_strided(
186186
std::uniform_int_distribution<int64_t>(low, high - 1));
187187
}
188188

189+
TensorPtr arange(
190+
executorch::aten::Scalar start,
191+
executorch::aten::Scalar end,
192+
executorch::aten::Scalar step,
193+
executorch::aten::ScalarType type,
194+
executorch::aten::TensorShapeDynamism dynamism) {
195+
// Calculate the number of elements in the range
196+
double start_val, end_val, step_val;
197+
198+
if (start.isFloatingPoint()) {
199+
start_val = start.to<double>();
200+
} else if (start.isIntegral(/*includeBool=*/false)) {
201+
start_val = static_cast<double>(start.to<int64_t>());
202+
} else {
203+
ET_CHECK_MSG(false, "start must be a number");
204+
}
205+
206+
if (end.isFloatingPoint()) {
207+
end_val = end.to<double>();
208+
} else if (end.isIntegral(/*includeBool=*/false)) {
209+
end_val = static_cast<double>(end.to<int64_t>());
210+
} else {
211+
ET_CHECK_MSG(false, "end must be a number");
212+
}
213+
214+
if (step.isFloatingPoint()) {
215+
step_val = step.to<double>();
216+
} else if (step.isIntegral(/*includeBool=*/false)) {
217+
step_val = static_cast<double>(step.to<int64_t>());
218+
} else {
219+
ET_CHECK_MSG(false, "step must be a number");
220+
}
221+
222+
ET_CHECK_MSG(step_val != 0, "step cannot be zero");
223+
224+
// Calculate the number of elements
225+
int64_t numel =
226+
static_cast<int64_t>(std::ceil((end_val - start_val) / step_val));
227+
numel = std::max(int64_t(0), numel); // Ensure non-negative
228+
229+
// Create a 1D tensor with the calculated size
230+
auto tensor = empty_strided({numel}, {1}, type, dynamism);
231+
232+
// Fill the tensor with values from start to end with step
233+
ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "arange", CTYPE, [&] {
234+
CTYPE* data = tensor->mutable_data_ptr<CTYPE>();
235+
for (int64_t i = 0; i < numel; ++i) {
236+
data[i] = static_cast<CTYPE>(start_val + i * step_val);
237+
}
238+
});
239+
240+
return tensor;
241+
}
242+
189243
} // namespace extension
190244
} // namespace executorch

extension/tensor/tensor_ptr_maker.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,5 +683,41 @@ inline TensorPtr randint(
683683
return randint_strided(low, high, std::move(sizes), {}, type, dynamism);
684684
}
685685

686+
/**
687+
* Creates a 1-D tensor with values from `start` to `end`
688+
* (exclusive) with step size `step`.
689+
*
690+
* @param start The starting value of the sequence.
691+
* @param end The ending value of the sequence (exclusive).
692+
* @param step The step size between values in the sequence; must be non-zero.
693+
* @param type The scalar type of the tensor elements.
694+
* @param dynamism Specifies whether the tensor's shape is static or dynamic.
695+
* @return A TensorPtr instance managing the newly created Tensor.
696+
*/
697+
TensorPtr arange(
698+
executorch::aten::Scalar start,
699+
executorch::aten::Scalar end,
700+
executorch::aten::Scalar step = 1,
701+
executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
702+
executorch::aten::TensorShapeDynamism dynamism =
703+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND);
704+
705+
/**
706+
* Creates a 1-D tensor with values from 0 to `end` (exclusive)
707+
* with step size 1.
708+
*
709+
* @param end The ending value of the sequence (exclusive).
710+
* @param type The scalar type of the tensor elements.
711+
* @param dynamism Specifies whether the tensor's shape is static or dynamic.
712+
* @return A TensorPtr instance managing the newly created Tensor.
713+
*/
714+
inline TensorPtr arange(
715+
executorch::aten::Scalar end,
716+
executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
717+
executorch::aten::TensorShapeDynamism dynamism =
718+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
719+
return arange(0, end, 1, type, dynamism);
720+
}
721+
686722
} // namespace extension
687723
} // namespace executorch

extension/tensor/test/tensor_ptr_maker_test.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,3 +506,90 @@ TEST_F(TensorPtrMakerTest, CreateRandnTensorWithIntType) {
506506
EXPECT_EQ(val, 0);
507507
}
508508
}
509+
510+
TEST_F(TensorPtrMakerTest, CreateArangeTensorWithDefaultStartAndStep) {
511+
auto tensor = arange(5);
512+
513+
EXPECT_EQ(tensor->dim(), 1);
514+
EXPECT_EQ(tensor->size(0), 5);
515+
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::Float);
516+
517+
for (auto i = 0; i < tensor->numel(); ++i) {
518+
auto val = tensor->const_data_ptr<float>()[i];
519+
EXPECT_EQ(val, static_cast<float>(i));
520+
}
521+
}
522+
523+
TEST_F(TensorPtrMakerTest, CreateArangeTensorWithStartEndStep) {
524+
auto tensor = arange(2, 10, 2);
525+
526+
EXPECT_EQ(tensor->dim(), 1);
527+
EXPECT_EQ(tensor->size(0), 4); // (10-2)/2 = 4 elements
528+
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::Float);
529+
530+
for (auto i = 0; i < tensor->numel(); ++i) {
531+
auto val = tensor->const_data_ptr<float>()[i];
532+
EXPECT_EQ(val, static_cast<float>(2 + i * 2));
533+
}
534+
}
535+
536+
TEST_F(TensorPtrMakerTest, CreateArangeTensorWithNegativeStep) {
537+
auto tensor = arange(5, 0, -1);
538+
539+
EXPECT_EQ(tensor->dim(), 1);
540+
EXPECT_EQ(tensor->size(0), 5);
541+
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::Float);
542+
543+
for (auto i = 0; i < tensor->numel(); ++i) {
544+
auto val = tensor->const_data_ptr<float>()[i];
545+
EXPECT_EQ(val, static_cast<float>(5 - i));
546+
}
547+
}
548+
549+
TEST_F(TensorPtrMakerTest, CreateArangeTensorWithIntType) {
550+
auto tensor = arange(0, 5, 1, executorch::aten::ScalarType::Int);
551+
552+
EXPECT_EQ(tensor->dim(), 1);
553+
EXPECT_EQ(tensor->size(0), 5);
554+
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::Int);
555+
556+
for (auto i = 0; i < tensor->numel(); ++i) {
557+
auto val = tensor->const_data_ptr<int32_t>()[i];
558+
EXPECT_EQ(val, i);
559+
}
560+
}
561+
562+
TEST_F(TensorPtrMakerTest, CreateArangeTensorWithLongType) {
563+
auto tensor = arange(0, 5, 1, executorch::aten::ScalarType::Long);
564+
565+
EXPECT_EQ(tensor->dim(), 1);
566+
EXPECT_EQ(tensor->size(0), 5);
567+
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::Long);
568+
569+
for (auto i = 0; i < tensor->numel(); ++i) {
570+
auto val = tensor->const_data_ptr<int64_t>()[i];
571+
EXPECT_EQ(val, static_cast<int64_t>(i));
572+
}
573+
}
574+
575+
TEST_F(TensorPtrMakerTest, CreateArangeTensorWithDoubleType) {
576+
auto tensor = arange(0.5, 5.5, 0.5, executorch::aten::ScalarType::Double);
577+
578+
EXPECT_EQ(tensor->dim(), 1);
579+
EXPECT_EQ(tensor->size(0), 10); // (5.5-0.5)/0.5 = 10 elements
580+
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::Double);
581+
582+
for (auto i = 0; i < tensor->numel(); ++i) {
583+
auto val = tensor->const_data_ptr<double>()[i];
584+
EXPECT_DOUBLE_EQ(val, 0.5 + i * 0.5);
585+
}
586+
}
587+
588+
TEST_F(TensorPtrMakerTest, CreateArangeTensorWithEmptyRange) {
589+
// End < start with positive step should give empty tensor
590+
auto tensor = arange(5, 0, 1);
591+
592+
EXPECT_EQ(tensor->dim(), 1);
593+
EXPECT_EQ(tensor->size(0), 0);
594+
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::Float);
595+
}

0 commit comments

Comments
 (0)