Skip to content

Commit 6520e06

Browse files
authored
Make type of logits a template parameter
Differential Revision: D84211619 Pull Request resolved: pytorch#14921
1 parent 38b51aa commit 6520e06

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -586,12 +586,12 @@ class StaticAttentionIOManager {
586586
* of the prompt and method's input length. Returns the position in the output
587587
* that corresponds to the end of the prompt during the last inference.
588588
*/
589-
template <typename TokenT>
589+
template <typename TokenT, typename LogitT>
590590
size_t prefill(
591591
executorch::runtime::Span<TokenT> tokens,
592592
executorch::runtime::Span<TokenT> input_buffer,
593593
executorch::runtime::Method& method,
594-
std::function<void(executorch::runtime::Span<const float>)>
594+
std::function<void(executorch::runtime::Span<const LogitT>)>
595595
logits_callback = nullptr) {
596596
ET_LOG(Info, "Prefilling at position %zu", input_pos_);
597597
size_t input_len = input_buffer.size();
@@ -619,7 +619,7 @@ class StaticAttentionIOManager {
619619
batch_len);
620620
if (logits_callback) {
621621
auto logits_tensor = method.get_output(0).toTensor();
622-
auto* logits = logits_tensor.const_data_ptr<float>();
622+
auto* logits = logits_tensor.const_data_ptr<LogitT>();
623623
logits_callback(executorch::runtime::Span(
624624
logits,
625625
logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1)));

0 commit comments

Comments
 (0)