@@ -144,8 +144,15 @@ def _triton_cached_ssm(
     num_seq = num_prefill + num_decode
     num_total_tokens = num_prefill_tokens + num_decode

-    y_prefill = None
-    y_decode = None
+    # Preallocate output tensor to avoid memcpy cost for merging prefill
+    # and decode outputs
+    preallocated_ssm_out = torch.empty(
+        [bs, num_heads, head_dim],
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
+    preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens]
+    preallocated_ssm_out_d = preallocated_ssm_out[num_prefill_tokens:num_total_tokens]

     # Prefill: concatenate tokens at the front and run combined scan
     if num_prefill > 0:
@@ -165,7 +172,7 @@ def _triton_cached_ssm(
         chunk_indices = None
         chunk_offsets = None

-        y_prefill, varlen_states = mamba_chunk_scan_combined(
+        varlen_states = mamba_chunk_scan_combined(
             hs_prefill,
             dt_prefill,
             A,
@@ -184,11 +191,12 @@ def _triton_cached_ssm(
             dt_limit=(time_step_limit[0], time_step_limit[1]),
             return_final_states=False,
             return_varlen_states=True,
-            mamba_ssm_cache_dtype=ssm_state_cache.dtype,
+            out=preallocated_ssm_out_p.unsqueeze(0),
+            state_dtype=ssm_state_cache.dtype,
         )

         ssm_state_cache.index_copy_(
-            0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
+            0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype)
         )

     # Decode: batch single-token updates via selective_state_update
@@ -205,7 +213,7 @@ def _triton_cached_ssm(
         A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
         D_full = D[..., None].expand(num_heads, head_dim)

-        y_decode = selective_state_update(
+        selective_state_update(
             ssm_state_cache,
             x_decode,
             dt_hp,
@@ -217,19 +225,16 @@ def _triton_cached_ssm(
             dt_bias=dt_bias_hp,
             dt_softplus=True,
             state_batch_indices=slot_idx_decode,
-        )  # [nd, H, D]
-
-    # Dispatch return logic
-    if num_prefill > 0 and num_decode > 0:
-        y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
-        y_flat = y.view(bs, *y.shape[2:])
-        y_flat[:num_prefill_tokens].copy_(y_prefill[0])
-        y_flat[num_prefill_tokens:num_total_tokens].copy_(y_decode)
-        return y
-    elif num_prefill > 0:
-        return y_prefill[0].view(b, s, num_heads, head_dim).to(hidden_states.dtype)
-    elif num_decode > 0:
-        return y_decode.view(b, s, num_heads, head_dim).to(hidden_states.dtype)
+            out=preallocated_ssm_out_d,
+        )
+
+    # Return the preallocated output reshaped to original dimensions
+    if num_total_tokens > 0:
+        return (
+            preallocated_ssm_out[:num_total_tokens]
+            .view(b, s, num_heads, head_dim)
+            .to(hidden_states.dtype)
+        )
     else:
         return torch.empty_like(hidden_states)

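For context, a minimal, self-contained sketch of the pattern this change adopts: allocate one flat output buffer up front, hand the prefill and decode kernels non-overlapping views of it via their `out=` parameters, and reshape once at the end. The `run_prefill`/`run_decode` helpers below are hypothetical stand-ins for `mamba_chunk_scan_combined` and `selective_state_update`; the point is only that writes through the views land in adjacent storage, so the two `copy_` calls in the old dispatch logic become unnecessary.

```python
import torch

def run_prefill(x: torch.Tensor, out: torch.Tensor) -> None:
    # Hypothetical stand-in for mamba_chunk_scan_combined(..., out=...):
    # the real kernel writes its result directly into `out`.
    out.copy_(x * 2.0)

def run_decode(x: torch.Tensor, out: torch.Tensor) -> None:
    # Hypothetical stand-in for selective_state_update(..., out=...).
    out.copy_(x + 1.0)

num_prefill_tokens, num_decode, num_heads, head_dim = 5, 3, 2, 4
num_total_tokens = num_prefill_tokens + num_decode
hidden = torch.randn(num_total_tokens, num_heads, head_dim)

# One allocation; the two slices below are views into the same storage.
out = torch.empty(
    num_total_tokens, num_heads, head_dim,
    dtype=hidden.dtype, device=hidden.device,
)
out_p = out[:num_prefill_tokens]
out_d = out[num_prefill_tokens:num_total_tokens]

run_prefill(hidden[:num_prefill_tokens], out_p)
run_decode(hidden[num_prefill_tokens:], out_d)

# The merged result is just a reshape of the preallocated buffer; no
# per-branch copy_ into a fresh tensor, as the old dispatch logic did.
y = out.view(1, num_total_tokens, num_heads, head_dim)
assert y.data_ptr() == out.data_ptr()  # same storage, zero extra copies
```

The `.unsqueeze(0)` on the prefill slice in the diff mirrors the kernel's leading batch dimension of 1; since `unsqueeze` also returns a view, the write still targets the shared buffer.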
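Separately, the added `.long()` cast is worth a note: `Tensor.index_copy_` expects its index argument as an int64 (Long) tensor, so an int32 `slot_idx` needs the cast. A tiny illustration, assuming (as the diff suggests) that `slot_idx` arrives as int32:

```python
import torch

cache = torch.zeros(4, 2)
new_states = torch.ones(2, 2)
slot_idx = torch.tensor([1, 3], dtype=torch.int32)  # assumed int32, as in the diff

# index_copy_ requires a LongTensor index, so the int32 tensor must be
# cast first; hence slot_idx[:num_prefill].long() in the change above.
cache.index_copy_(0, slot_idx.long(), new_states)
```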