
Commit dbaf61a

nit: lint and fmt
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 2f192bd commit dbaf61a

File tree

3 files changed: +17 -10 lines changed


.github/workflows/build-and-publish.yml

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ jobs:
           - "attention-and-distributed-packing"
           - "accelerated-moe"
           - "online-data-mixing"
+          - "mamba-cp"
 
     permissions:
       id-token: write # IMPORTANT: this permission is mandatory for trusted publishing

.github/workflows/format.yml

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ jobs:
           - "attention-and-distributed-packing"
           - "accelerated-moe"
           - "online-data-mixing"
+          - "mamba-cp"
 
     steps:
       - name: Delete huge unnecessary tools folder

plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py

Lines changed: 15 additions & 10 deletions
@@ -14,25 +14,28 @@
 # Standard
 from typing import Dict
 
-# Third Party
 try:
+    # Third Party
     from mamba_ssm.modules.mamba2_cp import Mamba2CP
-except ImportError:
-    raise ValueError("custom mamba_ssm package installation is needed"
-        "install from https://github.com/garrett361/mamba/tree/mamba-cp"
-    )
+except ImportError as exc:
+    raise ValueError(
+        "custom mamba_ssm package installation is needed"
+        "install from https://github.com/garrett361/mamba/tree/mamba-cp"
+    ) from exc
+# Third Party
 from accelerate.logging import get_logger
 
 # pylint: disable=import-error
 from torch.distributed._tensor.device_mesh import init_device_mesh
-from tqdm import tqdm
-from transformers.modeling_utils import is_fsdp_enabled
-import torch
 
 # to avoid rechunking/sharding of the buffers
 # ideally this is not optimal
 # this is done to make self attention cp compatible with mamba cp
 from torch.distributed.tensor.experimental._attention import _cp_options
+from tqdm import tqdm
+from transformers.modeling_utils import is_fsdp_enabled
+import torch
+
 _cp_options.enable_load_balance = False
 
 logger = get_logger(__name__)
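
For context on the import-guard change in this hunk: re-raising with "from exc" chains the new error to the original ImportError, which preserves the full traceback and addresses pylint's raise-missing-from (W0707) warning for a re-raise inside an except block. A minimal, generic sketch of the same pattern follows; some_optional_dep is a placeholder name, not a real dependency.

# Sketch only: guarded optional import with exception chaining.
# "some_optional_dep" is a placeholder package name, not a real dependency.
try:
    import some_optional_dep
except ImportError as exc:
    # "raise ... from exc" keeps the original ImportError in the traceback
    # and satisfies pylint's raise-missing-from (W0707) check.
    raise ValueError(
        "some_optional_dep is required; "
        "install it before importing this module"
    ) from exc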
@@ -42,7 +45,8 @@
 key_cp = "cp"
 key_rep = "dp_shard"
 
-# extract ssm config from hf config to be used
+
+# extract ssm config from hf config to be used
 # while swapping the mamba modules
 def get_ssmconfig_from_hfconfig(hf_config) -> Dict:
     config_ssm = {}
@@ -75,6 +79,7 @@ def forward(
         inference_params=None,
     )
 
+
 # patches each mamba module with mamba cp module
 # mamba cp module's weights are exactly same as hf mamba module
 # so we reuse the state dict and the same does not need special handling
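
The comments in this hunk describe the swap strategy: the CP-enabled Mamba module keeps the same parameter names and shapes as the Hugging Face module, so its weights can be carried over by reloading the existing state dict. A rough sketch of that kind of in-place swap under that assumption; swap_module_reusing_weights and build_cp_module are hypothetical names, not the plugin's actual API.

# Hypothetical sketch of swapping a submodule while reusing its weights.
# Assumes the replacement module has an identical state-dict layout.
import torch.nn as nn


def swap_module_reusing_weights(parent: nn.Module, child_name: str, build_cp_module):
    old_module = getattr(parent, child_name)
    new_module = build_cp_module(old_module)  # hypothetical factory for the CP variant
    # Same parameter names/shapes, so the old weights load directly
    # and the checkpoint needs no special handling.
    new_module.load_state_dict(old_module.state_dict())
    setattr(parent, child_name, new_module)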
@@ -94,7 +99,7 @@ def patch_mamba_layers_with_cp_head(
     if is_fsdp_enabled():
         device = torch.device("cpu")
     rep_size = world_size // cp_degree
-
+
     # auto infer ddp and cp ranks
     # does not work on other combination of parallelisms
     logger.warning(
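
For reference, the rep_size = world_size // cp_degree split in this last hunk matches the usual layout of a two-dimensional device mesh: one replication/data-parallel dimension and one context-parallel dimension, mirroring the key_rep/key_cp constants earlier in the file. A hedged sketch using the public torch.distributed.device_mesh API follows; the plugin imports init_device_mesh from the private _tensor path, and its actual mesh construction may differ.

# Sketch only: build a (dp_shard, cp) device mesh from the world size.
# Assumes torch.distributed has already been initialized.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def build_dp_cp_mesh(cp_degree: int):
    world_size = dist.get_world_size()
    rep_size = world_size // cp_degree  # number of replication / dp_shard groups
    return init_device_mesh(
        "cuda",
        (rep_size, cp_degree),
        mesh_dim_names=("dp_shard", "cp"),
    )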

0 commit comments
