File tree Expand file tree Collapse file tree 1 file changed +10
-1
lines changed
examples/models/llama/runner Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Original file line number Diff line number Diff line change @@ -589,7 +589,9 @@ class StaticAttentionIOManager {
589589 size_t prefill (
590590 executorch::runtime::Span<TokenT> tokens,
591591 executorch::runtime::Span<TokenT> input_buffer,
592- executorch::runtime::Method& method) {
592+ executorch::runtime::Method& method,
593+ std::function<void (executorch::runtime::Span<const float >)>
594+ logits_callback = nullptr) {
593595 ET_LOG (Info, " Prefilling at position %zu" , input_pos_);
594596 size_t input_len = input_buffer.size ();
595597 auto & masks = get_mask (input_buffer.size ());
@@ -610,6 +612,13 @@ class StaticAttentionIOManager {
610612 config_.k_cache_output_indices ,
611613 config_.v_cache_output_indices ,
612614 batch_len);
615+ if (logits_callback) {
616+ auto logits_tensor = method.get_output (0 ).toTensor ();
617+ auto * logits = logits_tensor.const_data_ptr <float >();
618+ logits_callback (executorch::runtime::Span (
619+ logits,
620+ logits + batch_len * logits_tensor.size (logits_tensor.dim () - 1 )));
621+ }
613622 }
614623 return batch_len - 1 ;
615624 }
You can’t perform that action at this time.
0 commit comments