+import logging
 import math
 
-import causal_conv1d
 import einops
 import mamba_ssm.ops.triton.ssd_combined
 import torch
 from fast_llm.layers.common.linear import Linear
 from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames
 from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_
+from fast_llm.utils import get_lr_scale
+
+logger = logging.getLogger(__name__)
+
+try:
+    import causal_conv1d
+except ImportError:
+    # Fall back to torch.nn.functional.conv1d, since causal_conv1d cannot be used on B200 GPUs for now.
+    logger.warning("causal_conv1d not found, will use torch.nn.functional.conv1d instead")
+    causal_conv1d = None
 
 """
 This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py
@@ -44,6 +54,9 @@ def __init__(
         bias = config.add_bias_linear
         self.layer_idx = layer_idx
         self._return_input = return_input
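+        # Combine the optional per-layer LR scale with the Mamba-specific LR scale; applied to every parameter below.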
+        layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None
+        mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale)
+        logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}")
 
         td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim)
         td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim)
@@ -67,31 +80,41 @@ def __init__(
 
         # TODO: double check initializations
         # Projections
-        self.in_proj = Linear(td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size))
+        self.in_proj = Linear(
+            td_model,
+            td_inner_proj,
+            bias=bias,
+            weight_init_method=kaiming_init_(td_model.size),
+            lr_scale=mamba_layer_lr_scale,
+        )
         self.z_bias = (
             ParameterMeta.from_dims(
                 (td_inner,),
                 weight_decay=False,
                 init_method=init_zeros_,
+                lr_scale=mamba_layer_lr_scale,
             )
             if not bias
             else 0.0
         )
 
-        # Convolutional layer
         self.conv1d_weight = ParameterMeta.from_dims(
             (td_conv, TensorDim("1", 1), td_conv_kernel),
             init_method=init_uniform_(
                 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size)
             ),  # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67
+            lr_scale=mamba_layer_lr_scale,
+        )
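+        # Bias init is derived from the conv weight via bias_init_method; it gets the same LR scale as the weight.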
+        self.conv1d_bias = ParameterMeta.from_dims(
+            (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale
         )
-        self.conv1d_bias = ParameterMeta.from_dims((td_conv,), init_method=bias_init_method(self.conv1d_weight))
 
         # D "skip" parameter
         self.D = ParameterMeta.from_dims(
             (td_n_qk_heads,),
             weight_decay=False,
             init_method=init_ones_,
+            lr_scale=mamba_layer_lr_scale,
         )
 
         # out_proj
@@ -100,6 +123,7 @@ def __init__(
             td_model,
             bias=bias,
             weight_init_method=kaiming_init_(td_inner.size),
+            lr_scale=mamba_layer_lr_scale,
         )
 
     @property
@@ -210,10 +234,25 @@ def forward(self, hidden_states, kwargs):
 
     def convolutional_forward(self, xBC, padded_len):
         """Convolutional layer forward pass for the full sequence."""
-        xBC = causal_conv1d.causal_conv1d_fn(
-            xBC.transpose(1, 2),
-            einops.rearrange(self.conv1d_weight, "d 1 w -> d w"),
-            self.conv1d_bias,
-            activation=None if self.activation_name == "identity" else self.activation_name,
-        ).transpose(1, 2)
+        if causal_conv1d is None or self.activation_name not in [
+            "silu",
+            "swish",
+            "identity",
+        ]:
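+            # Fallback path (fused kernel unavailable, or an activation it does not support): a depthwise conv
+            # (groups = channels) padded by (kernel_size - 1) and truncated to the first padded_len outputs
+            # is equivalent to a causal convolution over the sequence.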
+            xBC = self.act(
+                torch.nn.functional.conv1d(
+                    xBC.transpose(1, 2),
+                    self.conv1d_weight,
+                    bias=self.conv1d_bias,
+                    groups=self.conv1d_weight.shape[0],
+                    padding=self.conv_kernel_size - 1,
+                )[..., :padded_len].transpose(1, 2)
+            )
+        else:
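+            # Fast path: fused causal_conv1d CUDA kernel (applies silu/swish internally; None means no activation).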
+            xBC = causal_conv1d.causal_conv1d_fn(
+                xBC.transpose(1, 2),
+                einops.rearrange(self.conv1d_weight, "d 1 w -> d w"),
+                self.conv1d_bias,
+                activation=None if self.activation_name == "identity" else self.activation_name,
+            ).transpose(1, 2)
         return xBC