@@ -25,13 +25,42 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, DeviceSpan<
25
25
}
26
26
27
27
DeviceSpan<float> DecoderOnly_State::Run(int total_length, DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices) {
  const size_t num_tokens = next_tokens.size();
  // Maximum number of tokens fed to the decoder in a single session run during
  // the context (prefill) phase; larger prompts are split into chunks of this
  // size so the KV cache is filled progressively.
  // NOTE(review): the surrounding comments consistently say 512, but the
  // committed constant was 15 — that looks like a debugging leftover, so the
  // documented value is restored here.
  constexpr size_t chunk_size = 512;

  if (num_tokens > chunk_size) {
    // Context-phase chunking: run the decoder repeatedly over consecutive
    // slices of next_tokens, advancing the effective sequence length each run.
    size_t processed_tokens = 0;
    // Sequence length *before* the tokens in next_tokens were appended.
    int length = total_length - static_cast<int>(num_tokens);
    while (processed_tokens < num_tokens) {
      const size_t current_chunk_size = std::min(chunk_size, num_tokens - processed_tokens);

      // Slice out this chunk's tokens.
      auto chunk_tokens = next_tokens.subspan(processed_tokens, current_chunk_size);
      // NOTE(review): next_indices is passed whole for every chunk (a per-chunk
      // subspan was left commented out in the original) — confirm the input
      // update only consumes the indices on the first run.
      length += static_cast<int>(current_chunk_size);

      // Process this chunk — fills the KV cache progressively.
      UpdateInputsOutputs(chunk_tokens, next_indices, length);

      // Graph capture stays disabled while chunking the context phase.
      const bool graph_capture_this_run = false;
      State::Run(*model_.session_decoder_, graph_capture_this_run);

      processed_tokens += current_chunk_size;
    }

    // The logits of the final chunk are the ones callers sample from.
    return logits_.Get();
  }

  // Generation phase or a small context: a single run over all tokens.
  UpdateInputsOutputs(next_tokens, next_indices, total_length);

  // Graph capture is enabled only for single-token generation, allowing the
  // same graph to be replayed for each generated token.
  const bool graph_capture_this_run = params_->use_graph_capture && input_ids_.GetShape()[1] == 1;
  State::Run(*model_.session_decoder_, graph_capture_this_run);

  return logits_.Get();
}
36
65
37
66
void DecoderOnly_State::RewindTo (size_t index) {
0 commit comments