9 changes: 9 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -519,6 +519,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
checkpoint=checkpoint_path,
checkpoint_dir=checkpoint_dir,
params_path=params_path,
use_int32_token=True if args.qnn else False,
use_kv_cache=args.use_kv_cache,
use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
generate_full_logits=args.generate_full_logits,
@@ -746,6 +747,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901

def _load_llama_model_metadata(
weight_type: WeightType,
use_int32_token: bool,
use_kv_cache: bool,
use_sdpa_with_kv_cache: bool,
enable_dynamic_shape: bool,
@@ -759,6 +761,7 @@ def _load_llama_model_metadata(
"get_max_seq_len": model_args.max_seq_len,
"get_n_layers": model_args.n_layers,
"get_vocab_size": model_args.vocab_size,
"use_int32_token": use_int32_token,
"use_kv_cache": use_kv_cache,
"use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
"enable_dynamic_shape": enable_dynamic_shape,
@@ -779,6 +782,7 @@ def _load_llama_model(
checkpoint: Optional[str] = None,
checkpoint_dir: Optional[str] = None,
params_path: str,
use_int32_token: bool = False,
use_kv_cache: bool = False,
use_sdpa_with_kv_cache: bool = False,
generate_full_logits: bool = False,
@@ -852,6 +856,10 @@ def _load_llama_model(
else:
raise ValueError(f"Unsupported dtype {dtype}")

if use_int32_token:
token = example_inputs[0].to(torch.int32)
example_inputs = (token,) + example_inputs[1:]

return LLMEdgeManager(
model=model,
modelname=modelname,
Expand All @@ -870,6 +878,7 @@ def _load_llama_model(
verbose=verbose,
metadata=_load_llama_model_metadata(
weight_type,
use_int32_token,
use_kv_cache,
use_sdpa_with_kv_cache,
enable_dynamic_shape,
4 changes: 4 additions & 0 deletions examples/models/llama/runner/runner.cpp
@@ -34,6 +34,7 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
static constexpr auto kVocabSize = "get_vocab_size";
static constexpr auto kUseKVCache = "use_kv_cache";
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
static constexpr auto kUseInt32Token = "use_int32_token";
} // namespace

Runner::Runner(
@@ -51,6 +52,7 @@ Runner::Runner(
{kMaxSeqLen, 128},
{kUseKVCache, true},
{kUseSDPAWithKVCache, false},
{kUseInt32Token, true},
}) {
ET_LOG(
Info,
@@ -127,12 +129,14 @@ Error Runner::load() {
temperature_);
text_prefiller_ = std::make_unique<llm::TextPrefiller>(
text_decoder_runner_.get(),
metadata_.at(kUseInt32Token),
metadata_.at(kUseKVCache),
metadata_.at(kEnableDynamicShape));

text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
tokenizer_.get(),
text_decoder_runner_.get(),
metadata_.at(kUseInt32Token),
metadata_.at(kUseKVCache),
std::move(eos_ids),
&stats_);
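Across the runner change above and the prefiller/generator changes below, the metadata flag only ever decides which scalar type is handed to from_blob when the token tensors are wrapped; the start_pos tensors stay int64 in every call site shown here. A minimal sketch of that shared choice (the helper name and the include path are assumptions for illustration, not code from this PR):

#include <executorch/runtime/core/exec_aten/exec_aten.h>  // exec_aten::ScalarType (assumed path)

// Hypothetical helper: the prefiller and the token generator each compute this
// ternary inline; a shared helper would state the rule once.
inline exec_aten::ScalarType token_scalar_type(bool use_int32_token) {
  return use_int32_token ? exec_aten::ScalarType::Int
                         : exec_aten::ScalarType::Long;
}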
13 changes: 8 additions & 5 deletions extension/llm/runner/text_prefiller.cpp
@@ -17,9 +17,11 @@ namespace llm {

TextPrefiller::TextPrefiller(
TextDecoderRunner* text_decoder_runner,
bool use_int32_token,
Contributor:

Oh, I meant more like

TextPrefiller::TextPrefiller(
    TextDecoderRunner* text_decoder_runner,
    bool use_kv_cache,
    bool enable_parallel_prefill,
    bool use_int32_token = false)

so we don't need to update all the call sites...

Collaborator (Author):

Sorry about the misunderstanding. I have updated it. Thanks!

bool use_kv_cache,
bool enable_parallel_prefill)
: text_decoder_runner_(text_decoder_runner),
use_int32_token_(use_int32_token),
use_kv_cache_(use_kv_cache),
enable_parallel_prefill_(enable_parallel_prefill) {}

@@ -36,12 +38,13 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(

// store the token
uint64_t cur_token;
exec_aten::ScalarType token_type = use_int32_token_
? exec_aten::ScalarType::Int
: exec_aten::ScalarType::Long;
if (enable_parallel_prefill_ || !use_kv_cache_) {
// initialize tensor wrappers
auto tokens = from_blob(
prompt_tokens.data(),
{1, num_prompt_tokens},
exec_aten::ScalarType::Long);
auto tokens =
from_blob(prompt_tokens.data(), {1, num_prompt_tokens}, token_type);

auto start_pos_tensor =
from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
Expand All @@ -60,7 +63,7 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
cur_token = prompt_tokens[0];

// initialize tensor wrappers
auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long);
auto tokens = from_blob(&cur_token, {1, 1}, token_type);

auto start_pos_tensor =
from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
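Picking up the review thread above: a minimal sketch of the backward-compatible shape the reviewer suggests, with the new flag as a trailing parameter defaulting to false so existing call sites keep compiling. In C++ the default has to live on the declaration, so the sketch shows the header-side class; it is an illustration of the suggestion, not the code this PR landed.

class TextDecoderRunner;  // provided elsewhere by the runner library

class TextPrefiller {
 public:
  TextPrefiller(
      TextDecoderRunner* text_decoder_runner,
      bool use_kv_cache,
      bool enable_parallel_prefill,
      bool use_int32_token = false)  // trailing default keeps old callers source-compatible
      : text_decoder_runner_(text_decoder_runner),
        use_kv_cache_(use_kv_cache),
        enable_parallel_prefill_(enable_parallel_prefill),
        use_int32_token_(use_int32_token) {}

 private:
  TextDecoderRunner* text_decoder_runner_;
  bool use_kv_cache_;
  bool enable_parallel_prefill_;
  bool use_int32_token_;
};

Callers such as Runner::load() can still pass the flag explicitly, while embedders that construct the prefiller directly keep building and simply fall back to int64 tokens.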
2 changes: 2 additions & 0 deletions extension/llm/runner/text_prefiller.h
@@ -24,6 +24,7 @@ class ET_EXPERIMENTAL TextPrefiller {
public:
TextPrefiller(
TextDecoderRunner* text_decoder_runner,
bool use_int32_token,
bool use_kv_cache_,
bool enable_parallel_prefill);
/**
@@ -40,6 +41,7 @@

private:
TextDecoderRunner* text_decoder_runner_;
bool use_int32_token_;
bool use_kv_cache_;
bool enable_parallel_prefill_;
};
9 changes: 7 additions & 2 deletions extension/llm/runner/text_token_generator.h
@@ -23,12 +23,14 @@ class ET_EXPERIMENTAL TextTokenGenerator {
TextTokenGenerator(
Tokenizer* tokenizer,
TextDecoderRunner* text_decoder_runner,
bool use_int32_token,
Contributor:

This is BC-breaking and causes a CI failure. Can we set a default value so it's BC-compatible? Also assert when it doesn't match the metadata.

Contributor:

The CI-breaking stack trace can be found here.

bool use_kv_cache,
std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
Stats* stats)
: tokenizer_(tokenizer),
text_decoder_runner_(text_decoder_runner),
eos_ids_(std::move(eos_ids)),
use_int32_token_(use_int32_token),
use_kv_cache_(use_kv_cache),
stats_(stats) {}

@@ -54,6 +56,9 @@

std::vector<uint64_t> token_data; // allocate space for the tokens
std::vector<executorch::aten::SizesType> token_shape;
exec_aten::ScalarType token_type = use_int32_token_
? exec_aten::ScalarType::Int
: exec_aten::ScalarType::Long;

// Token after prefill
uint64_t cur_token = tokens.back();
@@ -70,8 +75,7 @@
}

// initialize tensor wrappers
auto tokens_managed = from_blob(
token_data.data(), token_shape, executorch::aten::ScalarType::Long);
auto tokens_managed = from_blob(token_data.data(), token_shape, token_type);
auto start_pos_managed =
from_blob(&pos, {1}, executorch::aten::ScalarType::Long);

@@ -133,6 +137,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
Tokenizer* tokenizer_;
TextDecoderRunner* text_decoder_runner_;
std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
bool use_int32_token_;
bool use_kv_cache_;

// state machine
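On the "assert when it doesn't match the metadata" part of the review comments above, a minimal sketch of how a caller such as Runner::load() could cross-check a defaulted flag against the exported program's metadata. ET_CHECK_MSG is the ExecuTorch runtime assertion macro; the include path and the exact wiring are assumptions for illustration, not code from this PR.

#include <executorch/runtime/platform/assert.h>  // ET_CHECK_MSG (assumed path)

#include <cstdint>

// Hypothetical check: refuse to finish loading if the flag handed to the
// prefiller and token generator disagrees with what the .pte metadata reports.
inline void check_token_dtype_matches_metadata(
    bool use_int32_token,
    int64_t metadata_use_int32_token) {
  ET_CHECK_MSG(
      use_int32_token == static_cast<bool>(metadata_use_int32_token),
      "use_int32_token does not match the exported model's metadata");
}

Combined with a trailing default value, this keeps older call sites compiling while catching the case where a model exported with int32 tokens is driven by a caller that never opted in.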