
Commit 3936fa9

Experimental: Add support for prefix-chunking
1 parent 9b80483 commit 3936fa9

File tree

1 file changed

+34 -5 lines changed


src/models/decoder_only.cpp

Lines changed: 34 additions & 5 deletions
@@ -28,13 +28,42 @@ void DecoderOnly_State::SetExtraInputs(const std::vector<ExtraInput>& extra_inpu
 }

 DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
-  UpdateInputsOutputs(next_tokens, next_indices, total_length);
+  size_t num_tokens = next_tokens.size();
+  const size_t chunk_size = 15;
+
+  if (num_tokens > chunk_size) {
+    // Chunking logic for the context phase - process the prompt chunk_size tokens at a time
+    size_t processed_tokens = 0;
+    int length = total_length - static_cast<int>(num_tokens);
+    while (processed_tokens < num_tokens) {
+      size_t current_chunk_size = std::min(chunk_size, num_tokens - processed_tokens);
+
+      // Create a subspan for the current chunk
+      auto chunk_tokens = next_tokens.subspan(processed_tokens, current_chunk_size);
+      // auto chunk_indices = next_indices.subspan(processed_tokens, current_chunk_size);
+      length = length + static_cast<int>(current_chunk_size);
+      // Process this chunk - fills the KV cache progressively
+      UpdateInputsOutputs(chunk_tokens, next_indices, length);
+
+      // Graph capture is disabled during context-phase chunking
+      bool graph_capture_this_run = false;
+      State::Run(*model_.session_decoder_, graph_capture_this_run);
+
+      processed_tokens += current_chunk_size;
+    }
+
+    // Return the logits from the last chunk for sampling
+    return logits_.Get();
+  } else {
+    // Original logic for num_tokens <= chunk_size (generation phase or a small context)
+    UpdateInputsOutputs(next_tokens, next_indices, total_length);

-  // Graph capture enabled for token generation case, allowing it to repeat the same graph for each token.
-  bool graph_capture_this_run = params_->use_graph_capture && input_ids_.GetShape()[1] == 1;
-  State::Run(*model_.session_decoder_, graph_capture_this_run);
+    // Graph capture enabled for token generation case, allowing it to repeat the same graph for each token.
+    bool graph_capture_this_run = params_->use_graph_capture && input_ids_.GetShape()[1] == 1;
+    State::Run(*model_.session_decoder_, graph_capture_this_run);

-  return logits_.Get();
+    return logits_.Get();
+  }
 }

 void DecoderOnly_State::RewindTo(size_t index) {
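For context, the arithmetic driving the chunk loop can be exercised on its own. The sketch below is a minimal standalone approximation and not part of the commit: it assumes a plain std::vector<int32_t> in place of DeviceSpan<int32_t>, and a hypothetical ProcessChunk() standing in for the UpdateInputsOutputs() + State::Run() pair, purely to show how processed_tokens and the KV-cache length advance as the prompt is consumed chunk_size tokens at a time.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for one decoder forward pass over a chunk.
static void ProcessChunk(const int32_t* tokens, size_t count, int kv_length) {
  (void)tokens;
  std::printf("ran chunk of %zu tokens, KV cache length is now %d\n", count, kv_length);
}

int main() {
  const size_t chunk_size = 15;                            // same experimental value as in the commit
  std::vector<int32_t> prompt(40, 0);                      // stand-in for a 40-token prompt
  const size_t num_tokens = prompt.size();
  const int total_length = static_cast<int>(num_tokens);   // nothing generated yet

  size_t processed_tokens = 0;
  int length = total_length - static_cast<int>(num_tokens);  // tokens already in the KV cache (0 here)
  while (processed_tokens < num_tokens) {
    const size_t current = std::min(chunk_size, num_tokens - processed_tokens);
    length += static_cast<int>(current);                   // cache grows by one chunk per forward pass
    ProcessChunk(prompt.data() + processed_tokens, current, length);
    processed_tokens += current;
  }
  return 0;
}

With a 40-token prompt and chunk_size = 15, this runs chunks of 15, 15, and 10 tokens; only the logits produced by the final chunk are relevant for sampling the first generated token, which is why the real Run() returns logits_.Get() after the loop.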
