File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
examples/models/llama/runner Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -586,12 +586,12 @@ class StaticAttentionIOManager {
586586 * of the prompt and method's input length. Returns the position in the output
587587 * that corresponds to the end of the prompt during the last inference.
588588 */
589- template <typename TokenT>
589+ template <typename TokenT, typename LogitT >
590590 size_t prefill (
591591 executorch::runtime::Span<TokenT> tokens,
592592 executorch::runtime::Span<TokenT> input_buffer,
593593 executorch::runtime::Method& method,
594- std::function<void (executorch::runtime::Span<const float >)>
594+ std::function<void (executorch::runtime::Span<const LogitT >)>
595595 logits_callback = nullptr) {
596596 ET_LOG (Info, " Prefilling at position %zu" , input_pos_);
597597 size_t input_len = input_buffer.size ();
@@ -619,7 +619,7 @@ class StaticAttentionIOManager {
619619 batch_len);
620620 if (logits_callback) {
621621 auto logits_tensor = method.get_output (0 ).toTensor ();
622- auto * logits = logits_tensor.const_data_ptr <float >();
622+ auto * logits = logits_tensor.const_data_ptr <LogitT >();
623623 logits_callback (executorch::runtime::Span (
624624 logits,
625625 logits + batch_len * logits_tensor.size (logits_tensor.dim () - 1 )));
You can’t perform that action at this time.
0 commit comments