Skip to content

Commit 3584da9

Browse files
authored
StaticAttentionIOManager: optional callback on logits from prefill
Differential Revision: D82150606. Pull Request resolved: #14336.
1 parent 2f21092 commit 3584da9

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -589,7 +589,9 @@ class StaticAttentionIOManager {
589589
size_t prefill(
590590
executorch::runtime::Span<TokenT> tokens,
591591
executorch::runtime::Span<TokenT> input_buffer,
592-
executorch::runtime::Method& method) {
592+
executorch::runtime::Method& method,
593+
std::function<void(executorch::runtime::Span<const float>)>
594+
logits_callback = nullptr) {
593595
ET_LOG(Info, "Prefilling at position %zu", input_pos_);
594596
size_t input_len = input_buffer.size();
595597
auto& masks = get_mask(input_buffer.size());
@@ -610,6 +612,13 @@ class StaticAttentionIOManager {
610612
config_.k_cache_output_indices,
611613
config_.v_cache_output_indices,
612614
batch_len);
615+
if (logits_callback) {
616+
auto logits_tensor = method.get_output(0).toTensor();
617+
auto* logits = logits_tensor.const_data_ptr<float>();
618+
logits_callback(executorch::runtime::Span(
619+
logits,
620+
logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1)));
621+
}
613622
}
614623
return batch_len - 1;
615624
}

0 commit comments

Comments (0)