Commit f83cc8d

docs: add docs
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 1990451 commit f83cc8d

2 files changed (+42, -7 lines)

plugins/mamba-cp/README.md

Lines changed: 15 additions & 0 deletions
@@ -1 +1,16 @@
 # Context Parallel for Mamba Kernels
+
+This library contains a plugin for applying context parallelism to the Mamba module (mamba_ssm).
+
+## Plugins
+
+Plugin | Description | Depends | Loading | Augmentation | Callbacks
+--|--|--|--|--|--
+[mcp](./src/fms_acceleration_mcp/framework_plugin_mcp.py) | context parallel for mamba | [custom mamba cp implementation](https://github.com/garrett361/mamba/tree/mamba-cp) | ✅ | ✅ | ✅
+
+## Mamba CP Implementation
+
+The context parallel implementation is taken from a custom [mamba_ssm repo](https://github.com/garrett361/mamba/tree/mamba-cp) that includes a CP implementation, so that repo must be installed to use this plugin.
+
+## Known Issues
+
+1. Load balancing is disabled because of limited support in the Mamba CP implementation. This can lead to throughput drops for training runs that use a causal mask (see the sketch below).
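The toggle behind this known issue is a single experimental PyTorch option, set in the utils.py change below. A minimal sketch of the effect, using only the calls that appear in this commit:

```python
# Minimal sketch: how the plugin disables load balancing for context-parallel
# attention (the same two lines appear in fms_acceleration_mcp/utils/utils.py).
# Without load balancing, each CP rank keeps one contiguous chunk of the
# sequence, so with a causal mask the later ranks do more attention work,
# which is the throughput drop noted above.

# pylint: disable=import-error
from torch.distributed.tensor.experimental._attention import _cp_options

_cp_options.enable_load_balance = False
```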

plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py

Lines changed: 27 additions & 7 deletions
@@ -15,7 +15,13 @@
 from typing import Dict
 
 # Third Party
-from mamba_ssm.modules.mamba2_cp import Mamba2CP
+try:
+    from mamba_ssm.modules.mamba2_cp import Mamba2CP
+except ImportError:
+    raise ValueError("custom mamba_ssm package installation is needed; "
+        "install from https://github.com/garrett361/mamba/tree/mamba-cp"
+    )
+from accelerate.logging import get_logger
 
 # pylint: disable=import-error
 from torch.distributed._tensor.device_mesh import init_device_mesh
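Since the plugin hard-fails when only stock mamba_ssm is installed, it can be handy to probe for the fork up front. A minimal sketch: the module path is taken from the guarded import above, while the helper name is ours.

```python
# Minimal sketch: check whether the CP-enabled mamba_ssm fork is installed.
# The module path mamba_ssm.modules.mamba2_cp only exists in the
# https://github.com/garrett361/mamba/tree/mamba-cp branch (per the import above).
import importlib.util


def has_mamba_cp() -> bool:
    """Return True if the CP-enabled Mamba2 module can be imported."""
    try:
        return importlib.util.find_spec("mamba_ssm.modules.mamba2_cp") is not None
    except ModuleNotFoundError:  # mamba_ssm itself is not installed
        return False


if __name__ == "__main__":
    print("mamba-cp fork available:", has_mamba_cp())
```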
@@ -25,15 +31,20 @@
 
 # to avoid rechunking/sharding of the buffers
 # (this is not ideal)
+# this is done to make self-attention CP compatible with Mamba CP
 from torch.distributed.tensor.experimental._attention import _cp_options
 _cp_options.enable_load_balance = False
 
+logger = get_logger(__name__)
 
+# the same keys are used in accelerate,
+# so we keep these in sync and cross-leverage them.
 key_cp = "cp"
 key_rep = "dp_shard"
 
-
-def hf_config_ssm_config(hf_config) -> Dict:
+# extract the ssm config from the hf config, to be used
+# while swapping the mamba modules
+def get_ssmconfig_from_hfconfig(hf_config) -> Dict:
     config_ssm = {}
     config_ssm["d_model"] = hf_config.hidden_size
     config_ssm["d_state"] = 128
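`key_cp` / `key_rep` mirror accelerate's parallelism keys, and the file imports `init_device_mesh`, which points to a 2-D mesh over data-parallel replicas and context-parallel ranks. The plugin's actual mesh construction is not shown in this diff, so the following is only a sketch with illustrative sizes:

```python
# Sketch of a DP x CP device mesh using the same dim names as the plugin
# ("dp_shard" and "cp"). Sizes are illustrative; run under torchrun so the
# process group can initialize from the environment.
from torch.distributed._tensor.device_mesh import init_device_mesh

world_size = 8                       # total ranks (illustrative)
cp_degree = 2                        # ranks that split each sequence
rep_size = world_size // cp_degree   # data-parallel replica groups

mesh = init_device_mesh(
    "cuda",
    (rep_size, cp_degree),
    mesh_dim_names=("dp_shard", "cp"),
)
cp_group = mesh["cp"].get_group()    # group over which the sequence is sharded
```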
@@ -45,6 +56,7 @@ def hf_config_ssm_config(hf_config) -> Dict:
     return config_ssm
 
 
+# patches input arguments between the mamba cp module and the standard hf mamba module
 class Mamba2CPHF(Mamba2CP):
     def forward(
         self,
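Per the comment, `Mamba2CPHF` only adapts call arguments: HF decoder layers pass transformer-style arguments, while the mamba_ssm module expects the hidden states directly. The full forward body is not shown in this hunk, so the following is a hedged sketch under the assumption that `Mamba2CP` keeps mainline `Mamba2`'s `forward(u, ..., inference_params=None)` signature:

```python
# Hedged sketch of an HF <-> mamba_ssm argument shim, not the plugin's exact code.
# Assumes Mamba2CP.forward takes the hidden states as its first positional
# argument plus an inference_params keyword, like mainline mamba_ssm's Mamba2.
class Mamba2CPHFSketch(Mamba2CP):  # hypothetical name, to avoid shadowing the real class
    def forward(self, hidden_states, *args, **kwargs):
        # HF layers pass extra arguments (attention mask, cache position, ...)
        # that the mamba_ssm module does not accept; drop them and call through.
        return super().forward(
            hidden_states,
            inference_params=None,
        )
```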
@@ -63,7 +75,10 @@ def forward(
             inference_params=None,
         )
 
-
+# patches each mamba module with a mamba cp module;
+# the mamba cp module's weights are exactly the same as the hf mamba module's,
+# so we reuse the state dict and it needs no special handling
+# while checkpointing.
 def patch_mamba_layers_with_cp_head(
     model,
     checkpoint_name_or_path,
@@ -74,12 +89,18 @@ def patch_mamba_layers_with_cp_head(
     cp_mamba_recompute,
 ):
 
-    config_ssm = hf_config_ssm_config(model.config)
+    config_ssm = get_ssmconfig_from_hfconfig(model.config)
     device = torch.device(f"cuda:{rank}")
     if is_fsdp_enabled():
         device = torch.device("cpu")
     rep_size = world_size // cp_degree
-
+
+    # auto-infer the ddp and cp ranks;
+    # this does not work with other combinations of parallelisms
+    logger.warning(
+        "Mamba CP is only meant for parallelism combinations having DP and CP; "
+        "other combinations can lead to unexpected behaviour"
+    )
     if cp_degree == 1:
         raise ValueError("CP degree can't be one")
     if rep_size == 1:
@@ -100,7 +121,6 @@ def patch_mamba_layers_with_cp_head(
         "cp_mamba_impl": cp_mamba_impl,
         "cp_mamba_recompute": cp_mamba_recompute,
     }
-
     with torch.no_grad():
         dtype = model.dtype
         device = model.device
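Per the comments above, the CP module's parameters line up one-to-one with the HF Mamba module's, which is why the swap can reuse the existing state dict and checkpoints need no special handling. A hedged sketch of that pattern (the `mamba` attribute name is an illustrative assumption, not taken from the plugin):

```python
# Hedged sketch of the "swap module, reuse state dict" idea from the comments.
# Only the load_state_dict pattern is taken from the diff; the attribute name
# `mamba` on the HF layer is an assumption for illustration.
import torch


def swap_in_cp_module(hf_layer, cp_module):
    """Replace the HF mamba mixer on hf_layer with an already-built CP module."""
    with torch.no_grad():
        # weights match one-to-one, so the old state dict loads directly
        cp_module.load_state_dict(hf_layer.mamba.state_dict())
    hf_layer.mamba = cp_module
    return hf_layer
```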
