# Standard
from typing import Dict

-# Third Party
try:
+    # Third Party
    from mamba_ssm.modules.mamba2_cp import Mamba2CP
-except ImportError:
-    raise ValueError("custom mamba_ssm package installation is needed"
-        "install from https://github.com/garrett361/mamba/tree/mamba-cp"
-    )
+except ImportError as exc:
+    raise ValueError(
+        "custom mamba_ssm package installation is needed; "
+        "install from https://github.com/garrett361/mamba/tree/mamba-cp"
+    ) from exc
+# Third Party
from accelerate.logging import get_logger

# pylint: disable=import-error
from torch.distributed._tensor.device_mesh import init_device_mesh
-from tqdm import tqdm
-from transformers.modeling_utils import is_fsdp_enabled
-import torch

# to avoid rechunking/sharding of the buffers
# this is not ideal, but it is done to make
# self attention cp compatible with mamba cp
from torch.distributed.tensor.experimental._attention import _cp_options
+from tqdm import tqdm
+from transformers.modeling_utils import is_fsdp_enabled
+import torch
+
_cp_options.enable_load_balance = False

logger = get_logger(__name__)
key_cp = "cp"
key_rep = "dp_shard"

-# extract ssm config from hf config to be used
+
+# extract ssm config from hf config to be used
# while swapping the mamba modules
def get_ssmconfig_from_hfconfig(hf_config) -> Dict:
    config_ssm = {}
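
For illustration only, not part of this patch: a minimal sketch of how the key_rep/key_cp dimension names and init_device_mesh imported above might be combined into a 2D mesh. The build_cp_mesh helper, the "cuda" device type, and the (rep_size, cp_degree) layout are assumptions.

def build_cp_mesh(world_size: int, cp_degree: int):
    # hypothetical helper: split the world into a replicate/shard dim and a cp dim
    rep_size = world_size // cp_degree
    return init_device_mesh(
        "cuda",
        (rep_size, cp_degree),
        mesh_dim_names=(key_rep, key_cp),  # ("dp_shard", "cp")
    )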
@@ -75,6 +79,7 @@ def forward(
        inference_params=None,
    )

+
# patches each mamba module with mamba cp module
# mamba cp module's weights are exactly the same as the hf mamba module
# so we reuse the state dict and it does not need special handling
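
A minimal sketch of the swap idea described in the comments above, not the actual patch_mamba_layers_with_cp_head body. The model.model.layers / layer.mamba attribute names and the Mamba2CP(**ssm_cfg, cp_mesh=...) constructor call are assumptions for illustration.

def swap_mamba_modules_sketch(model, hf_config, cp_mesh):
    # hypothetical walk over decoder layers; attribute names are illustrative
    ssm_cfg = get_ssmconfig_from_hfconfig(hf_config)
    for layer in model.model.layers:
        hf_mamba = layer.mamba
        cp_mamba = Mamba2CP(**ssm_cfg, cp_mesh=cp_mesh)
        # weights match the hf module exactly, so the state dict loads directly
        cp_mamba.load_state_dict(hf_mamba.state_dict())
        layer.mamba = cp_mamba
    return model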
@@ -94,7 +99,7 @@ def patch_mamba_layers_with_cp_head(
    if is_fsdp_enabled():
        device = torch.device("cpu")
    rep_size = world_size // cp_degree
-
+
    # auto infer ddp and cp ranks
    # does not work on other combinations of parallelisms
    logger.warning(
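
A sketch of the rank bookkeeping hinted at by the comments above, assuming cp is the fastest-varying mesh dimension (matching a ("dp_shard", "cp") layout); the helper name and the direct torch.distributed rank queries are assumptions, not code from this patch.

import torch.distributed as dist

def infer_ranks_sketch(cp_degree: int):
    # hypothetical helper: derive dp/cp coordinates from the flat global rank
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    rep_size = world_size // cp_degree   # number of dp_shard groups
    cp_rank = global_rank % cp_degree    # position within the cp group
    rep_rank = global_rank // cp_degree  # which dp_shard group this rank sits in
    return rep_rank, cp_rank, rep_size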