 from transformers.modeling_utils import is_fsdp_enabled
 import torch
 
-key_ep = "cp"
+# to avoid rechunking/sharding of the buffers
+# ideally this is not optimal
+from torch.distributed.tensor.experimental._attention import _cp_options
+_cp_options.enable_load_balance = False
+
+
+key_cp = "cp"
 key_rep = "dp_shard"
 
 
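For context on the hunk above: the load-balance toggle appears to need to be set before any context-parallel attention region runs, which is presumably why it moved from inside patch_mamba_layers_with_cp_head to module scope. A minimal sketch of that ordering, assuming PyTorch's experimental context_parallel context manager is what ultimately consumes the option; the run_cp_forward helper and the sequence-dim choice are illustrative, not code from this PR:

# Illustrative sketch (not from this PR): set the option once at import time,
# before any context-parallel region is entered; with load balancing off each
# rank keeps one contiguous sequence chunk, so buffers are not rechunked.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import _cp_options

_cp_options.enable_load_balance = False


def run_cp_forward(model, hidden_states, cp_degree):
    # hypothetical helper: shard hidden_states along its sequence dim (dim 1)
    # across a 1-D "cp" mesh and run the model on the local shard
    cp_mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=("cp",))
    with context_parallel(cp_mesh, buffers=[hidden_states], buffer_seq_dims=[1]):
        return model(hidden_states)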
@@ -67,10 +73,6 @@ def patch_mamba_layers_with_cp_head(
     cp_mamba_impl,
     cp_mamba_recompute,
 ):
-    # to avoid rechunking/sharding of the buffers
-    # ideally this is not optimal
-    from torch.distributed.tensor.experimental._attention import _cp_options
-    _cp_options.enable_load_balance = False
 
     config_ssm = hf_config_ssm_config(model.config)
     device = torch.device(f"cuda:{rank}")
@@ -84,17 +86,17 @@ def patch_mamba_layers_with_cp_head(
         device_mesh = init_device_mesh(
             "cuda",
             (cp_degree,),
-            mesh_dim_names=(key_ep,),
+            mesh_dim_names=(key_cp,),
         )
     else:
         device_mesh = init_device_mesh(
             "cuda",
             (rep_size, cp_degree),
-            mesh_dim_names=(key_rep, key_ep),
+            mesh_dim_names=(key_rep, key_cp),
         )
 
     cp_args = {
-        "cp_mesh": device_mesh[key_ep],
+        "cp_mesh": device_mesh[key_cp],
         "cp_mamba_impl": cp_mamba_impl,
         "cp_mamba_recompute": cp_mamba_recompute,
     }
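A hedged sketch of how the mesh set up in this hunk exposes the CP submesh that cp_args["cp_mesh"] presumably hands to the patched Mamba layers; the build_cp_mesh helper, the world-size arithmetic, and the branch condition are assumptions for illustration, not code from the PR:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

key_cp = "cp"
key_rep = "dp_shard"


def build_cp_mesh(cp_degree):
    # hypothetical helper: a 1-D "cp" mesh when the whole world is one CP
    # group, otherwise a 2-D (dp_shard, cp) mesh as in the else-branch above
    world_size = dist.get_world_size()
    rep_size = world_size // cp_degree
    if rep_size == 1:
        mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=(key_cp,))
    else:
        mesh = init_device_mesh(
            "cuda", (rep_size, cp_degree), mesh_dim_names=(key_rep, key_cp)
        )
    # name-based indexing yields the 1-D "cp" submesh for this rank
    return mesh[key_cp]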