
Commit 6f07be3

Update on "[llm] Support different shape of input_pos"
For Hugging Face models, `forward()` takes `tokens` as well as `cache_positions`, a list of cache indices. This differs from the .pte files produced by `export_llama`, which take `tokens` and `input_pos`, where `input_pos` is a scalar tensor. This PR adds support inside `text_decoder_runner.cpp` for handling both shapes of `input_pos`/`cache_positions`. To keep the logic generic without relying on extra metadata, it inspects the method meta and the input tensor info to decide whether to feed `input_pos` or `cache_position`. Differential Revision: [D77203700](https://our.internmc.facebook.com/intern/diff/D77203700/) [ghstack-poisoned]
2 parents 9a698d7 + f5d5ae0 commit 6f07be3
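
For illustration, here is a minimal sketch of the kind of check the description refers to: inspecting the loaded method's input tensor metadata to decide whether the positional input should be fed as export_llama's scalar-like `input_pos` or as a HuggingFace-style `cache_positions` vector. The method name `"forward"`, the input index `1`, and the helper name `expects_scalar_input_pos` are assumptions made for this sketch, not code from the commit.

```cpp
// Sketch only (not the code in this commit): decide the position-input shape
// by inspecting method metadata rather than relying on extra model metadata.
#include <executorch/extension/module/module.h>

// Assumption: the position input is input index 1 of the "forward" method.
inline bool expects_scalar_input_pos(::executorch::extension::Module& module) {
  auto method_meta = module.method_meta("forward");
  if (!method_meta.ok()) {
    return true;  // fall back to the export_llama-style scalar input_pos
  }
  auto pos_info = method_meta->input_tensor_meta(1);
  if (!pos_info.ok()) {
    return true;
  }
  // sizes() is the TensorInfo accessor touched in method_meta.h below.
  auto sizes = pos_info->sizes();
  // A rank-1, single-element tensor matches export_llama's scalar `input_pos`;
  // anything longer is treated as a HuggingFace-style `cache_positions` list.
  return sizes.size() == 1 && sizes[0] == 1;
}
```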

File tree

3 files changed: +8 additions, -5 deletions

examples/models/llava/runner/llava_text_decoder_runner.h
extension/llm/runner/text_decoder_runner.h
runtime/executor/method_meta.h

examples/models/llava/runner/llava_text_decoder_runner.h

Lines changed: 6 additions & 3 deletions
@@ -11,25 +11,28 @@
 #pragma once
 
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/tensor/tensor.h>
 
 namespace example {
 
 class ET_EXPERIMENTAL LlavaTextDecoderRunner
     : public executorch::extension::llm::TextDecoderRunner {
  public:
   explicit LlavaTextDecoderRunner(executorch::extension::Module* module)
-      : TextDecoderRunner(module, true) {}
+      : TextDecoderRunner(module) {}
 
   inline executorch::runtime::Result<executorch::aten::Tensor> step(
       executorch::extension::TensorPtr& tokens,
-      executorch::extension::TensorPtr& start_pos) override {
+      int64_t start_pos) override {
     // run token embedding
     auto token_embedding_outputs =
         ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, tokens));
 
+    auto start_pos_tensor = ::executorch::extension::from_blob(
+        &start_pos, {1}, executorch::aten::ScalarType::Long);
     // run text model
     auto outputs_res = ET_UNWRAP(module_->execute(
-        kTextModelMethod, {start_pos, token_embedding_outputs[0]}));
+        kTextModelMethod, {start_pos_tensor, token_embedding_outputs[0]}));
 
     ET_CHECK_MSG(
         outputs_res.size() == 1,
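
For reference, a self-contained illustration (not from the commit) of the `from_blob` pattern the new `step()` body relies on: wrapping a stack `int64_t` as a one-element Long tensor without copying. The key constraint is that the wrapped integer must outlive every use of the tensor, which holds in `step()` because `start_pos` stays alive for the whole `execute()` call.

```cpp
// Standalone sketch of the zero-copy from_blob pattern used above.
#include <cstdint>

#include <executorch/extension/tensor/tensor.h>

int main() {
  int64_t start_pos = 7;
  // Shape {1}, dtype Long, aliasing &start_pos rather than copying it.
  auto start_pos_tensor = ::executorch::extension::from_blob(
      &start_pos, {1}, executorch::aten::ScalarType::Long);
  // start_pos must remain alive for as long as start_pos_tensor is used.
  (void)start_pos_tensor;
  return 0;
}
```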

extension/llm/runner/text_decoder_runner.h

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ namespace llm {
 
 class ET_EXPERIMENTAL TextDecoderRunner {
  public:
-  TextDecoderRunner(Module* module);
+  explicit TextDecoderRunner(Module* module);
 
   virtual ~TextDecoderRunner() = default;
 
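A brief note on the design choice (general C++, not specific to this commit): marking the single-argument constructor `explicit` stops a raw `Module*` from converting implicitly into a `TextDecoderRunner`, so the conversion has to be spelled out at the call site. A simplified sketch with assumed stand-in types:

```cpp
// Simplified stand-in types to show what `explicit` rules out.
struct Module {};

struct TextDecoderRunner {
  explicit TextDecoderRunner(Module* module) : module_(module) {}
  Module* module_;
};

void generate(const TextDecoderRunner& runner) { (void)runner; }

int main() {
  Module m;
  // generate(&m);                  // no longer compiles: implicit conversion blocked
  generate(TextDecoderRunner(&m));  // the construction must be written explicitly
  return 0;
}
```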

runtime/executor/method_meta.h

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ class TensorInfo final {
   /**
    * Returns the sizes of the tensor.
    */
-  Span<const ::executorch::aten::SizesType> sizes() const;
+  Span<const int32_t> sizes() const;
 
   /**
    * Returns the dim order of the tensor.
