Skip to content

Commit 9673ff4

Browse files
authored
Experimental: Add support for prefix-chunking
1 parent 497bd94 commit 9673ff4

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

src/models/decoder_only.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,42 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, DeviceSpan<
2525
}
2626

2727
DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
28-
UpdateInputsOutputs(next_tokens, next_indices, total_length);
28+
size_t num_tokens = next_tokens.size();
29+
const size_t chunk_size = 15;
30+
31+
if (num_tokens > chunk_size) {
32+
// Chunking logic for context phase - process the prompt in chunks of at most chunk_size tokens
33+
size_t processed_tokens = 0;
34+
int length = total_length - static_cast<int>(num_tokens);
35+
while (processed_tokens < num_tokens) {
36+
size_t current_chunk_size = std::min(chunk_size, num_tokens - processed_tokens);
37+
38+
// Create subspans for current chunk
39+
auto chunk_tokens = next_tokens.subspan(processed_tokens, current_chunk_size);
40+
//auto chunk_indices = next_indices.subspan(processed_tokens, current_chunk_size);
41+
length = length + static_cast<int>(current_chunk_size);
42+
// Process this chunk - fills KV cache progressively
43+
UpdateInputsOutputs(chunk_tokens, next_indices, length);
44+
45+
// Graph capture is typically disabled during context phase chunking
46+
bool graph_capture_this_run = false; // Disable graph capture during chunking
47+
State::Run(*model_.session_decoder_, graph_capture_this_run);
48+
49+
processed_tokens += current_chunk_size;
50+
}
51+
52+
// Return logits from the last chunk for potential sampling
53+
return logits_.Get();
54+
} else {
55+
// Original logic for num_tokens <= chunk_size (generation phase or small context)
56+
UpdateInputsOutputs(next_tokens, next_indices, total_length);
2957

30-
// Graph capture enabled for token generation case, allowing it to repeat the same graph for each token.
31-
bool graph_capture_this_run = params_->use_graph_capture && input_ids_.GetShape()[1] == 1;
32-
State::Run(*model_.session_decoder_, graph_capture_this_run);
58+
// Graph capture enabled for token generation case, allowing it to repeat the same graph for each token.
59+
bool graph_capture_this_run = params_->use_graph_capture && input_ids_.GetShape()[1] == 1;
60+
State::Run(*model_.session_decoder_, graph_capture_this_run);
3361

34-
return logits_.Get();
62+
return logits_.Get();
63+
}
3564
}
3665

3766
void DecoderOnly_State::RewindTo(size_t index) {

0 commit comments

Comments
 (0)