Skip to content

Commit ce7f5a0

Browse files
authored
[llama] Fix text prefiller
Differential Revision: D61069478 Pull Request resolved: #4660
1 parent c9e7714 commit ce7f5a0

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

extension/llm/runner/text_prefiller.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,31 @@ Result<uint64_t> TextPrefiller::prefill(
7878

7979
ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);
8080

81+
// run the first token and get back logits tensor. Assuming the first token
82+
// is bos so don't callback.
83+
exec_aten::Tensor logits_tensor = ET_UNWRAP(
84+
text_decoder_runner_->step(managed_tokens, managed_start_pos));
85+
pos = 1; // start from index 1
86+
8187
while (pos < num_prompt_tokens) {
8288
// Run the model
8389
pos_data = start_pos + pos;
8490

85-
Result<exec_aten::Tensor> logits_res =
86-
text_decoder_runner_->step(managed_tokens, managed_start_pos);
87-
88-
ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
8991
prev_token = cur_token;
9092

91-
pos++;
93+
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
94+
cur_token = prompt_tokens[pos];
9295

93-
cur_token = pos == num_prompt_tokens
94-
? text_decoder_runner_->logits_to_token(logits_res.get())
95-
: prompt_tokens[pos];
96+
logits_tensor = ET_UNWRAP(
97+
text_decoder_runner_->step(managed_tokens, managed_start_pos));
9698

9799
// print the token as string, decode it with the Tokenizer object
98100
token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
101+
102+
pos++;
99103
}
104+
105+
cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
100106
}
101107
return cur_token;
102108
}

0 commit comments

Comments
 (0)