1 file changed: +10 -1 lines changed
examples/models/llama/runner — 1 file changed: +10 -1 lines changed
Columns: original file line number | diff line number | diff line change
@@ -589,7 +589,9 @@ class StaticAttentionIOManager {
589
589
size_t prefill (
590
590
executorch::runtime::Span<TokenT> tokens,
591
591
executorch::runtime::Span<TokenT> input_buffer,
592
- executorch::runtime::Method& method) {
592
+ executorch::runtime::Method& method,
593
+ std::function<void (executorch::runtime::Span<const float >)>
594
+ logits_callback = nullptr) {
593
595
ET_LOG (Info, " Prefilling at position %zu" , input_pos_);
594
596
size_t input_len = input_buffer.size ();
595
597
auto & masks = get_mask (input_buffer.size ());
@@ -610,6 +612,13 @@ class StaticAttentionIOManager {
610
612
config_.k_cache_output_indices ,
611
613
config_.v_cache_output_indices ,
612
614
batch_len);
615
+ if (logits_callback) {
616
+ auto logits_tensor = method.get_output (0 ).toTensor ();
617
+ auto * logits = logits_tensor.const_data_ptr <float >();
618
+ logits_callback (executorch::runtime::Span (
619
+ logits,
620
+ logits + batch_len * logits_tensor.size (logits_tensor.dim () - 1 )));
621
+ }
613
622
}
614
623
return batch_len - 1 ;
615
624
}
You can’t perform that action at this time.
0 commit comments