from typing import Dict

# Third Party
-from mamba_ssm.modules.mamba2_cp import Mamba2CP
+try:
+    from mamba_ssm.modules.mamba2_cp import Mamba2CP
+except ImportError:
+    raise ImportError(
+        "custom mamba_ssm package installation is needed, "
+        "install from https://github.com/garrett361/mamba/tree/mamba-cp"
+    )
+from accelerate.logging import get_logger

# pylint: disable=import-error
from torch.distributed._tensor.device_mesh import init_device_mesh

# to avoid rechunking/sharding of the buffers
# admittedly this is not optimal
+# this is done to make self-attention CP compatible with Mamba CP
from torch.distributed.tensor.experimental._attention import _cp_options
_cp_options.enable_load_balance = False

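# Illustration only, not part of this diff: a minimal sketch of why load balancing is
# disabled (an assumption based on the comments above). With load balancing, the
# experimental SDPA context parallelism gives each CP rank two non-contiguous sequence
# chunks; with it disabled, each rank keeps one contiguous chunk, which matches how the
# Mamba CP implementation shards the sequence.
import torch

seq_len, cp_degree = 8, 2
tokens = torch.arange(seq_len)

# load balancing off: one contiguous chunk per rank -> rank0: [0..3], rank1: [4..7]
contiguous_shards = list(tokens.chunk(cp_degree))

# load balancing on: 2*cp_degree chunks, rank i holds chunks i and 2*cp_degree-1-i
# -> rank0: [0, 1, 6, 7], rank1: [2, 3, 4, 5]
chunks = list(tokens.chunk(2 * cp_degree))
balanced_shards = [
    torch.cat([chunks[i], chunks[2 * cp_degree - 1 - i]]) for i in range(cp_degree)
]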
+logger = get_logger(__name__)

+# the same keys are used in accelerate,
+# so we keep these in sync and leverage them across both.
key_cp = "cp"
key_rep = "dp_shard"

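# Illustration only, not part of this diff: a minimal sketch of how these dim names are
# typically passed to init_device_mesh to build a 2-D (dp_shard x cp) mesh. The world
# size and CP degree below are assumptions for the example, and the snippet needs to run
# under torchrun with a process group available.
example_world_size, example_cp_degree = 8, 2
mesh = init_device_mesh(
    "cuda",
    (example_world_size // example_cp_degree, example_cp_degree),  # (dp_shard, cp) = (4, 2)
    mesh_dim_names=(key_rep, key_cp),  # same names accelerate uses, so the meshes stay in sync
)
cp_mesh = mesh[key_cp]  # sub-mesh holding the ranks that shard the sequence dimension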
-
-def hf_config_ssm_config(hf_config) -> Dict:
+# extract the SSM config from the HF config, to be used
+# while swapping in the Mamba CP modules
+def get_ssmconfig_from_hfconfig(hf_config) -> Dict:
    config_ssm = {}
    config_ssm["d_model"] = hf_config.hidden_size
    config_ssm["d_state"] = 128
@@ -45,6 +56,7 @@ def hf_config_ssm_config(hf_config) -> Dict:
    return config_ssm


+# adapts the forward() arguments between the Mamba CP module and the standard HF Mamba module
class Mamba2CPHF(Mamba2CP):
    def forward(
        self,
@@ -63,7 +75,10 @@ def forward(
            inference_params=None,
        )

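# Sketch only: the middle of forward() above is elided in this hunk, so the exact
# signature is an assumption. The idea is that HF calls its mixer as
# module(hidden_states, ...), while Mamba2CP.forward also accepts inference_params;
# the adapter bridges the two by always passing inference_params=None during training.
# _Mamba2CPAdapterSketch is a hypothetical name used only for this illustration.
class _Mamba2CPAdapterSketch(Mamba2CP):
    def forward(self, hidden_states, *args, **kwargs):
        # drop any HF-specific extras and disable the decoding-time state cache
        return super().forward(
            hidden_states,
            inference_params=None,
        )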
-
+# patches each Mamba module with the Mamba CP module
+# the Mamba CP module's weights are exactly the same as the HF Mamba module's,
+# so we reuse the state dict and it does not need special handling
+# during checkpointing.
def patch_mamba_layers_with_cp_head(
    model,
    checkpoint_name_or_path,
@@ -74,12 +89,18 @@ def patch_mamba_layers_with_cp_head(
    cp_mamba_recompute,
):

-    config_ssm = hf_config_ssm_config(model.config)
+    config_ssm = get_ssmconfig_from_hfconfig(model.config)
    device = torch.device(f"cuda:{rank}")
    if is_fsdp_enabled():
        device = torch.device("cpu")
    rep_size = world_size // cp_degree
-
+
+    # auto-infer the DP and CP ranks
+    # does not work with other combinations of parallelisms
+    logger.warning(
+        "Mamba CP is only meant for parallelism combinations having DP and CP; "
+        "other combinations can lead to unexpected behaviour"
+    )
    if cp_degree == 1:
        raise ValueError("CP degree can't be one")
    if rep_size == 1:
@@ -100,7 +121,6 @@ def patch_mamba_layers_with_cp_head(
100121 "cp_mamba_impl" : cp_mamba_impl ,
101122 "cp_mamba_recompute" : cp_mamba_recompute ,
102123 }
103-
104124 with torch .no_grad ():
105125 dtype = model .dtype
106126 device = model .device
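# Worked example, illustrative only (values are assumptions, and the elided
# rep_size == 1 branch is assumed to reject the configuration as well): a valid setup
# needs both a CP dimension and a replica (dp_shard) dimension.
example_world_size = 8
for example_cp_degree in (1, 2, 8):
    example_rep_size = example_world_size // example_cp_degree
    valid = example_cp_degree > 1 and example_rep_size > 1
    print(
        f"world_size={example_world_size} cp_degree={example_cp_degree} "
        f"dp_shard={example_rep_size} -> {'accepted' if valid else 'rejected'}"
    )
# world_size=8 cp_degree=1 dp_shard=8 -> rejected ("CP degree can't be one")
# world_size=8 cp_degree=2 dp_shard=4 -> accepted
# world_size=8 cp_degree=8 dp_shard=1 -> rejected (no DP/replica dimension left)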