Commit 486d41b

debug
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 28a20fe commit 486d41b

File tree

1 file changed: +10 -8 lines changed

  • plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py


plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py
10 additions, 8 deletions
@@ -23,7 +23,13 @@
 from transformers.modeling_utils import is_fsdp_enabled
 import torch
 
-key_ep = "cp"
+# to avoid rechunking/sharding of the buffers
+# ideally this is not optimal
+from torch.distributed.tensor.experimental._attention import _cp_options
+_cp_options.enable_load_balance = False
+
+
+key_cp = "cp"
 key_rep = "dp_shard"
 
 
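Reviewer note on the first hunk: the _cp_options.enable_load_balance = False assignment now runs at module import time instead of inside patch_mamba_layers_with_cp_head, so the flag is already set before any context-parallel buffers are chunked or sharded. The snippet below is only a sketch of why that flag matters, under the assumption that PyTorch's experimental context parallelism load-balances causal attention by giving rank i chunks i and 2*world_size-1-i of the sequence, and hands out one contiguous shard per rank when load balancing is disabled; shard_sequence is a hypothetical helper, not part of this plugin or of torch.

# Sketch only -- illustrates the assumed effect of _cp_options.enable_load_balance.
import torch

def shard_sequence(x, rank, world_size, load_balance):
    # Hypothetical helper: the sequence slice a given CP rank would own.
    if not load_balance:
        # One contiguous 1/world_size shard: state/buffer tensors never need
        # to be re-chunked across ranks.
        return x.chunk(world_size, dim=1)[rank]
    # Load-balanced mode: two non-adjacent chunks, concatenated, to even out
    # causal-attention work across ranks.
    chunks = x.chunk(2 * world_size, dim=1)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=1)

x = torch.arange(16).reshape(1, 16)                                  # toy sequence
print(shard_sequence(x, rank=0, world_size=4, load_balance=False))   # tensor([[0, 1, 2, 3]])
print(shard_sequence(x, rank=0, world_size=4, load_balance=True))    # tensor([[ 0,  1, 14, 15]])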

@@ -67,10 +73,6 @@ def patch_mamba_layers_with_cp_head(
     cp_mamba_impl,
     cp_mamba_recompute,
 ):
-    # to avoid rechunking/sharding of the buffers
-    # ideally this is not optimal
-    from torch.distributed.tensor.experimental._attention import _cp_options
-    _cp_options.enable_load_balance = False
 
     config_ssm = hf_config_ssm_config(model.config)
     device = torch.device(f"cuda:{rank}")
@@ -84,17 +86,17 @@ def patch_mamba_layers_with_cp_head(
         device_mesh = init_device_mesh(
             "cuda",
             (cp_degree,),
-            mesh_dim_names=(key_ep,),
+            mesh_dim_names=(key_cp,),
         )
     else:
         device_mesh = init_device_mesh(
             "cuda",
             (rep_size, cp_degree),
-            mesh_dim_names=(key_rep, key_ep),
+            mesh_dim_names=(key_rep, key_cp),
         )
 
     cp_args = {
-        "cp_mesh": device_mesh[key_ep],
+        "cp_mesh": device_mesh[key_cp],
         "cp_mamba_impl": cp_mamba_impl,
         "cp_mamba_recompute": cp_mamba_recompute,
     }
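The remaining hunks rename the mesh-dimension key from key_ep to key_cp so the name matches what the dimension actually is (context parallelism). As a rough, self-contained sketch of how the named mesh and the device_mesh[key_cp] lookup fit together, assuming torch.distributed is already initialized (e.g. under torchrun) and using made-up values for rep_size and cp_degree:

# Sketch of the mesh construction mirrored from utils.py; run under torchrun
# with the process group already initialized. Sizes below are illustrative only.
from torch.distributed.device_mesh import init_device_mesh

key_cp = "cp"
key_rep = "dp_shard"

rep_size, cp_degree = 2, 4      # e.g. 8 GPUs = 2 data replicas x 4-way context parallel

if rep_size == 1:
    device_mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=(key_cp,))
else:
    device_mesh = init_device_mesh(
        "cuda",
        (rep_size, cp_degree),
        mesh_dim_names=(key_rep, key_cp),
    )

# Indexing the mesh by dimension name returns the 1-D submesh containing the
# ranks in this rank's context-parallel group; its process group backs the
# context-parallel collectives.
cp_mesh = device_mesh[key_cp]
print(cp_mesh.size())        # 4
print(cp_mesh.get_group())   # ProcessGroup for this CP group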
