@@ -78,25 +78,31 @@ Result<uint64_t> TextPrefiller::prefill(
7878
7979 ManagedTensor managed_start_pos (&pos_data, {1 }, ScalarType::Long);
8080
81+ // run the first token and get back logits tensor. Assuming the first token
82+ // is bos so don't callback.
83+ exec_aten::Tensor logits_tensor = ET_UNWRAP (
84+ text_decoder_runner_->step (managed_tokens, managed_start_pos));
85+ pos = 1 ; // start from index 1
86+
8187 while (pos < num_prompt_tokens) {
8288 // Run the model
8389 pos_data = start_pos + pos;
8490
85- Result<exec_aten::Tensor> logits_res =
86- text_decoder_runner_->step (managed_tokens, managed_start_pos);
87-
88- ET_CHECK_OK_OR_RETURN_ERROR (logits_res.error ());
8991 prev_token = cur_token;
9092
91- pos++;
93+ // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
94+ cur_token = prompt_tokens[pos];
9295
93- cur_token = pos == num_prompt_tokens
94- ? text_decoder_runner_->logits_to_token (logits_res.get ())
95- : prompt_tokens[pos];
96+ logits_tensor = ET_UNWRAP (
97+ text_decoder_runner_->step (managed_tokens, managed_start_pos));
9698
9799 // print the token as string, decode it with the Tokenizer object
98100 token_callback (ET_UNWRAP (tokenizer_->decode (prev_token, cur_token)));
101+
102+ pos++;
99103 }
104+
105+ cur_token = text_decoder_runner_->logits_to_token (logits_tensor);
100106 }
101107 return cur_token;
102108}
0 commit comments