@@ -28,13 +28,42 @@ void DecoderOnly_State::SetExtraInputs(const std::vector<ExtraInput>& extra_inpu
28
28
}
29
29
30
30
DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
  const size_t num_tokens = next_tokens.size();

  // Maximum number of tokens fed to the decoder in a single session run during
  // the context (prefill) phase.
  // NOTE(review): the code previously used 15 here while every surrounding
  // comment said 512 — 15 looks like a leftover debug value, so the constant is
  // restored to 512. Confirm against the intended prefill chunk budget.
  constexpr size_t kChunkSize = 512;

  if (num_tokens > kChunkSize) {
    // Context-phase chunking: run the decoder on successive sub-spans of
    // next_tokens so the KV cache is filled progressively instead of in one
    // oversized session run.
    size_t processed_tokens = 0;

    // Sequence length already present in the KV cache before the first chunk;
    // it grows by each chunk's size so the final iteration sees
    // length == total_length.
    int length = total_length - static_cast<int>(num_tokens);

    while (processed_tokens < num_tokens) {
      const size_t current_chunk_size = std::min(kChunkSize, num_tokens - processed_tokens);

      // Sub-span of tokens for the current chunk. next_indices is passed
      // whole on every iteration.
      // NOTE(review): a commented-out next_indices.subspan(...) was left in the
      // original — verify whether indices must be chunked in lockstep with the
      // tokens for batched/beam inputs.
      auto chunk_tokens = next_tokens.subspan(processed_tokens, current_chunk_size);
      length += static_cast<int>(current_chunk_size);

      // Process this chunk — fills the KV cache progressively.
      UpdateInputsOutputs(chunk_tokens, next_indices, length);

      // Graph capture stays disabled while chunking: chunk shapes can differ
      // (the last chunk is usually shorter), so a captured graph could not be
      // replayed safely.
      constexpr bool graph_capture_this_run = false;
      State::Run(*model_.session_decoder_, graph_capture_this_run);

      processed_tokens += current_chunk_size;
    }

    // The last chunk produced the logits needed to sample the next token.
    return logits_.Get();
  }

  // Generation phase, or a context small enough to fit in one session run.
  UpdateInputsOutputs(next_tokens, next_indices, total_length);

  // Graph capture enabled for the single-token generation case, allowing the
  // same graph to be repeated for each generated token.
  const bool graph_capture_this_run = params_->use_graph_capture && input_ids_.GetShape()[1] == 1;
  State::Run(*model_.session_decoder_, graph_capture_this_run);

  return logits_.Get();
}
39
68
40
69
void DecoderOnly_State::RewindTo (size_t index) {
0 commit comments