@@ -4510,6 +4510,75 @@ def mamba_mixer_forward(
     return contextualized_states


+def falcon_mamba_mixer_forward(
+    self,
+    input_states,
+    cache_params=None,
+    cache_position: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.LongTensor] = None,
+):
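+    # FalconMamba variant of `mamba_mixer_forward` above: same export-friendly
+    # slow-path scan, except that B, C and the time step are RMS-normalized with
+    # `rms_forward` before the selective scan, matching FalconMamba's mixer.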
+    from transformers.models.falcon_mamba.modeling_falcon_mamba import rms_forward
+
+    batch_size, seq_len, _ = input_states.shape
+    dtype = input_states.dtype
+    # 1. Gated MLP's linear projection
+    projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
+    hidden_states, gate = projected_states.chunk(2, dim=1)
+
+    if attention_mask is not None:
+        hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+    # 2. Convolution sequence transformation
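+    # With a cache, the scripted `conv_sequence_transform` applies the causal conv and
+    # updates the rolling conv state in one export-friendly op; without a cache, a plain
+    # conv1d over the full sequence is used and the SSM state starts from zeros.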
+    if cache_params is not None:
+        ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+        ssm_state = ssm_state.to(hidden_states.device)
+        # use `cache_position.shape[0]` to check whether we are in the prefill
+        # stage; it's equivalent to checking `cache_position[0] == 0`, which
+        # breaks dynamo fullgraph constraints
+        hidden_states, conv_state = self.conv_sequence_transform(
+            hidden_states, cache_position, cache_params.conv_states[self.layer_idx]
+        )
+        cache_params.conv_states[self.layer_idx] = conv_state
+    else:
+        ssm_state = torch.zeros(
+            (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
+        )
+        hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]
+
+    if attention_mask is not None:
+        hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+    # 3. State Space Model sequence transformation
+    # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+    ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+    time_step, B, C = torch.split(
+        ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+    )
+
+    B = rms_forward(B, variance_epsilon=self.rms_eps)
+    C = rms_forward(C, variance_epsilon=self.rms_eps)
+    time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+    discrete_time_step = self.dt_proj(time_step)  # [batch, seq_len, intermediate_size]
+
+    discrete_time_step = torch.nn.functional.softplus(discrete_time_step)  # [batch, seq_len, intermediate_size]
+    A = -torch.exp(self.A_log.float())
+    B = B.float()
+    D = self.D.float()
+
+    scan_output, ssm_state = self.selective_scan(
+        ssm_state, hidden_states.float().transpose(1, 2), discrete_time_step, A, B, C, D
+    )
+    scan_output = scan_output.transpose(1, 2)
+    scan_output = scan_output * self.act(gate)
+
+    if cache_params is not None:
+        cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+    # 4. Final linear projection
+    contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
+    return contextualized_states
+
+
 class MambaPatcher(ModelPatcher):
     def __init__(
         self,
@@ -4684,3 +4753,22 @@ def __exit__(self, exc_type, exc_value, traceback):
         self._model.forward = self._model.__orig_forward
         for layer in self._model.backbone.layers:
             layer.mixer.forward = layer.mixer._orig_forward
+
+
+class FalconMambaPatcher(MambaPatcher):
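+    """Applies the FalconMamba-specific patches on top of `MambaPatcher`: each mixer's
+    forward is replaced with `falcon_mamba_mixer_forward`, a shared `SelectiveScan`
+    module is attached, and the convolution step is swapped for a TorchScript-compiled
+    `ConvSequenceTransform`. `MambaPatcher.__exit__` restores the original forwards."""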
+    def __enter__(self):
+        super().__enter__()
+        selective_scan = SelectiveScan()
+
+        for layer in self._model.backbone.layers:
+            layer.mixer.selective_scan = selective_scan
+            layer.mixer._orig_forward = layer.mixer.forward
+            layer.mixer.forward = types.MethodType(falcon_mamba_mixer_forward, layer.mixer)
+            conv_transform = ConvSequenceTransform(
+                layer.mixer.conv_kernel_size,
+                layer.mixer.use_conv_bias,
+                layer.mixer.conv1d,
+                layer.mixer.act,
+                layer.mixer.conv1d.bias,
+            )
+            layer.mixer.conv_sequence_transform = torch.jit.script(conv_transform)