@@ -92,6 +92,7 @@ def __init__(
9292 self .use_mem_eff_path = use_mem_eff_path
9393 self .layer_idx = layer_idx
9494 self .ssm_state = None
95+ self .conv_state = None
9596
9697 # Order: [z, x, B, C, dt]
9798 d_in_proj = 2 * self .d_inner + 2 * self .ngroups * self .d_state + self .nheads
@@ -161,7 +162,6 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
161162 Returns: same shape as u
162163 """
163164 seqlen_og = seqlen
164- #import pdb; pdb.set_trace()
165165 cache_device = self .in_proj .weight .device
166166 cache_dtype = self .in_proj .weight .dtype
167167 if seqlen is None :
@@ -170,18 +170,11 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
170170 batch_seqlen , dim = u .shape
171171 batch = batch_seqlen // seqlen
172172
173- should_cache_states = inference_params is None and cu_seqlens is None
173+ should_cache_states = inference_params is not None
174174 cached_state = self ._maybe_get_cached_state (batch , cache_device , cache_dtype ) if should_cache_states else None
175175
176176 conv_state , ssm_state = None , None
177- if inference_params is not None :
178- inference_batch = cu_seqlens .shape [0 ] - 1 if cu_seqlens is not None else batch
179- conv_state , ssm_state = self ._get_states_from_cache (inference_params , inference_batch )
180- if inference_params .seqlen_offset > 0 :
181- # The states are updated inplace
182- out , _ , _ = self .step (u , conv_state , ssm_state )
183- return out
184-
177+
185178 zxbcdt = self .in_proj (u ) # (B, L, d_in_proj) or (B * L, d_in_proj)
186179 if seqlen_og is not None :
187180 zxbcdt = rearrange (zxbcdt , "(b l) d -> b l d" , l = seqlen )
@@ -227,34 +220,37 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
227220 [d_mlp , d_mlp , self .d_ssm , self .d_ssm + 2 * self .ngroups * self .d_state , self .nheads ],
228221 dim = - 1
229222 )
223+
224+ assert self .activation in ["silu" , "swish" ]
225+ conv_state = self ._prepare_conv_state (xBC .transpose (1 , 2 ), batch )
230226 if conv_state is not None :
231227 if cu_seqlens is None :
232228 # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
233229 # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
234230 xBC_t = rearrange (xBC , "b l d -> b d l" )
235- conv_state . copy_ ( F .pad (xBC_t , (self .d_conv - xBC_t .shape [- 1 ], 0 ) )) # Update state (B D W)
231+ self . conv_state = F .pad (xBC_t , (self .d_conv - xBC_t .shape [- 1 ], 0 )) # Update state (B D W)
236232 else :
237233 assert causal_conv1d_varlen_states is not None , "varlen inference requires causal_conv1d package"
238234 assert batch == 1 , "varlen inference only supports batch dimension 1"
239235 conv_varlen_states = causal_conv1d_varlen_states (
240236 xBC .squeeze (0 ), cu_seqlens , state_len = conv_state .shape [- 1 ]
241237 )
242238 conv_state .copy_ (conv_varlen_states )
243- assert self .activation in ["silu" , "swish" ]
244239 if causal_conv1d_fn is None or self .activation not in ["silu" , "swish" ]:
245240 assert seq_idx is None , "varlen conv1d requires the causal_conv1d package"
246241 xBC = self .act (
247242 self .conv1d (xBC .transpose (1 , 2 )).transpose (1 , 2 )[:, :- (self .d_conv - 1 )]
248243 ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
249244 else :
250- xBC = causal_conv1d_fn (
245+ xBC = causal_conv1d_update (
251246 xBC .transpose (1 , 2 ),
252- rearrange (self .conv1d .weight , "d 1 w -> d w" ),
247+ conv_state = conv_state ,
248+ weight = rearrange (self .conv1d .weight , "d 1 w -> d w" ),
253249 bias = self .conv1d .bias ,
254250 activation = self .activation ,
255- seq_idx = seq_idx ,
256251 ).transpose (1 , 2 )
257252 x , B , C = torch .split (xBC , [self .d_ssm , self .ngroups * self .d_state , self .ngroups * self .d_state ], dim = - 1 )
253+
258254 return_varlen_states = cu_seqlens is not None and inference_params is not None
259255 initial_states = ssm_state if ssm_state is not None else cached_state
260256 return_final_states = (ssm_state is not None ) or should_cache_states
@@ -409,6 +405,7 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
409405
def reset_cache(self):
    """Discard any cached recurrent state.

    Clears both the SSM state and the convolution state so the next
    forward pass starts from a clean (zero/absent) cache.
    """
    self.ssm_state = self.conv_state = None
412409
413410 def _maybe_get_cached_state (self , batch_size , device , dtype ):
414411 if self .ssm_state is None :
@@ -427,3 +424,17 @@ def _update_cache_state(self, new_state):
427424 if cache_state .dtype != target_dtype :
428425 cache_state = cache_state .to (dtype = target_dtype )
429426 self .ssm_state = cache_state .contiguous ()
427+
428+ def _prepare_conv_state (self , x , batch_size ):
429+ state_len = self .d_conv - 1
430+ if state_len <= 0 :
431+ self .conv_state = None
432+ return None
433+ if (
434+ self .conv_state is None
435+ or self .conv_state .shape [0 ] != batch_size
436+ or self .conv_state .shape [1 ] != x .shape [1 ]
437+ ):
438+ self .conv_state = x .new_zeros (batch_size , x .shape [1 ], state_len )
439+ return self .conv_state
440+
0 commit comments