
Commit 2dbff48

feat: Allow for MoE kernels (Scatter MoE) irrespective of use of EP (#136)
Squashed commit messages (each signed off by Mehant Kammakomati <[email protected]>):

* feat: no ep yes kernels
* feat: support low_cpu_mem
* fix: review comments
* fix: review comments
* fix: address review comments
* fix: lint error for peft
* fix: review comment from Fabian
1 parent ee7d713 commit 2dbff48

File tree

4 files changed (+59, -23 lines):

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py
plugins/accelerated-peft/requirements.txt
plugins/framework/pyproject.toml

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 38 additions & 18 deletions
@@ -42,12 +42,30 @@ def __init__(self, configurations: Dict[str, Dict]):
         super().__init__(configurations)
 
         # ep_degree determines the expert parallel sharding
-        # - default of 1 means experts are not sharded and operate in pure replication.
+        # If disable_distributed==False, the MoE plugin handles sharding / replication;
+        # otherwise the user will need to handle this manually (e.g., using FSDP).
+        #
+        # ep_degree=1 (default):
+        # - disable_distributed=False (default) means
+        #   experts are replicated while using ScatterMoE kernels.
+        # - disable_distributed=True means no replication (please use your
+        #   own training framework).
+        #
+        # ep_degree > 1:
+        # - disable_distributed=False (default) means expert sharding with
+        #   ScatterMoE kernels.
+        # - disable_distributed=True cannot be set in this case; errors out.
+
         self._ep_degree = self._check_config_and_maybe_check_values(
             key="training.moe.scattermoe.ep_degree",
             default=1,
         )
 
+        self._disable_distributed = self._check_config_and_maybe_check_values(
+            key="training.moe.scattermoe.disable_distributed",
+            default=False,
+        )
+
     @property
     def requires_augmentation(self):
         return True
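The comment block above boils down to a small decision matrix. The helper below is purely illustrative (it does not exist in the plugin); it only mirrors how ep_degree and disable_distributed combine, including the invalid combination that prepare_scattermoe rejects further down.

# Hypothetical helper, only to illustrate the ep_degree x disable_distributed matrix
# described in the comment above; it is not part of the plugin.
def describe_scattermoe_mode(ep_degree: int = 1, disable_distributed: bool = False) -> str:
    if ep_degree > 1 and disable_distributed:
        # mirrors the ValueError added to prepare_scattermoe below
        raise ValueError("expert sharding cannot be deferred to a top-level protocol such as FSDP")
    if ep_degree == 1 and not disable_distributed:
        return "ScatterMoE kernels; experts replicated by the plugin"
    if ep_degree == 1 and disable_distributed:
        return "ScatterMoE kernels only; replication/sharding left to the outer training framework"
    return "ScatterMoE kernels with experts sharded across the expert-parallel group"

# e.g. describe_scattermoe_mode(1, True) corresponds to the new mode this commit enables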
@@ -77,6 +95,7 @@ def augmentation(
             rank=rank,
             world_size=world_size,
             ep_degree=self._ep_degree,
+            disable_distributed=self._disable_distributed,
             mixed_precision=False,  # Currently this is hardcoded to OFF
         )
         return model, modifiable_args
@@ -91,23 +110,24 @@ def get_callbacks_and_ready_for_train(
             and getattr(accelerator.state, "fsdp_plugin", None) is not None
         ):
 
-            # - use an internal function call to get the no split
-            # module names, which are typically layers
-            _layers = model._get_no_split_modules("")
-            accelerator.state.fsdp_plugin.ignored_modules = [
-                getattr(layer, name)
-                for name in self._moe_component_module_names
-                for layer in model.modules()
-                if layer.__class__.__name__ in _layers
-            ]
-
-            # call this to patch the HF save and load functions to be able
-            # to save DTensors properly
-            patch_huggingface_save_and_load_for_dtensors()
-
-            # call this to patch torch optim to not use
-            # foreach for dtensors
-            patch_torch_optim_foreach_to_not_apply_to_dtensors()
+            if not self._disable_distributed:
+                # - use an internal function call to get the no split
+                # module names, which are typically layers
+                _layers = model._get_no_split_modules("")
+                accelerator.state.fsdp_plugin.ignored_modules = [
+                    getattr(layer, name)
+                    for name in self._moe_component_module_names
+                    for layer in model.modules()
+                    if layer.__class__.__name__ in _layers
+                ]
+
+                # call this to patch the HF save and load functions to be able
+                # to save DTensors properly
+                patch_huggingface_save_and_load_for_dtensors()
+
+                # call this to patch torch optim to not use
+                # foreach for dtensors
+                patch_torch_optim_foreach_to_not_apply_to_dtensors()
 
         return callbacks
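As a rough sketch of the branch above, with toy module names standing in for the plugin's real classes: when the plugin keeps managing the MoE distribution itself (disable_distributed=False), the swapped-in ScatterMoE submodules are collected and handed to accelerate's FSDP plugin as ignored_modules, so FSDP does not wrap them a second time.

# Toy sketch only: ToyBlock / "moe" stand in for the model's no-split layers and the
# plugin's self._moe_component_module_names; they are not actual fms-acceleration names.
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = nn.Linear(8, 8)   # stand-in for the swapped-in ScatterMoE module
        self.attn = nn.Linear(8, 8)  # left for FSDP to wrap as usual

model = nn.Sequential(ToyBlock(), ToyBlock())

no_split_classes = {"ToyBlock"}      # what model._get_no_split_modules("") would list
moe_component_names = ["moe"]        # the plugin's tracked MoE attribute names

ignored = [
    getattr(layer, name)
    for name in moe_component_names
    for layer in model.modules()
    if layer.__class__.__name__ in no_split_classes
]
assert len(ignored) == 2  # one ScatterMoE stand-in per block
# accelerator.state.fsdp_plugin.ignored_modules = ignored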

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py

Lines changed: 19 additions & 4 deletions
@@ -104,6 +104,7 @@ def prepare_scattermoe(
     rank: int = None,
     world_size: int = None,
     ep_degree: int = 1,
+    disable_distributed: bool = False,
     key_rep: str = KEY_REPLICATE,
     key_ep: str = KEY_EXPERT_PARALLEL,
     device_type: str = "cuda",
@@ -116,6 +117,12 @@ def prepare_scattermoe(
     # pylint: disable=import-outside-toplevel
     from .scattermoe import ScatterMoE
 
+    if disable_distributed and ep_degree > 1:
+        raise ValueError(
+            "expert sharding cannot be deferred to the top-level sharding "
+            "protocol (e.g. FSDP) when ep_degree > 1"
+        )
+
     assert world_size % ep_degree == 0, (
         f"world size ({world_size}) " f"not divisible by ep_size ({ep_degree})."
     )
@@ -129,6 +136,9 @@ def prepare_scattermoe(
     # current rank of the device
     device = torch.device(f"{device_type}:{rank}")
 
+    if ep_degree == 1 and disable_distributed and is_fsdp_enabled() and rank == 0:
+        device = torch.device("cpu")
+
     # get the scattermoe conversion spec
     (
         moe_cls,
@@ -142,7 +152,8 @@ def prepare_scattermoe(
         expert_name = expert_name.split("|")
 
     rep_size = world_size // ep_degree
-    if ep_degree == 1 and rep_size == 1:
+
+    if ep_degree == 1:
         # in this case no need for sharding
         device_mesh = None
     elif rep_size == 1:
@@ -265,7 +276,10 @@ def prepare_scattermoe(
     )
 
     if device_mesh is None:
-        _init_scattermoe_context = nullcontext
+        if not is_fsdp_enabled() or is_local_dist_rank_0():
+            _init_scattermoe_context = nullcontext
+        else:
+            _init_scattermoe_context = init_empty_weights
     else:
         # in this case we need to distribute parameters, so just initialize
         # the scattermoe module swap with empty weights,
@@ -318,8 +332,9 @@ def prepare_scattermoe(
         if device_mesh is None:
             # - if not on meta, just load the state dict
             # - and then put on the device
-            moe.load_state_dict(sd)
-            moe = moe.to(device)
+            if not is_fsdp_enabled() or is_local_dist_rank_0():
+                moe.load_state_dict(sd)
+                moe = moe.to(device)
         else:
             # - otherwise, we need to distribute and will
             # replace the parameters
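Taken together, the hunks above implement the low-CPU-memory path: when there is no device mesh and FSDP is active, only local rank 0 materialises and loads the real weights (kept on CPU), while the other ranks build the module with empty (meta) weights and rely on FSDP to synchronise parameters later. Below is a condensed sketch of that pattern, assuming is_fsdp_enabled / is_local_dist_rank_0 are the transformers helpers of the same names and with a plain nn.Linear standing in for the ScatterMoE swap.

# Condensed sketch of the low-CPU-memory path added above; a plain nn.Linear
# stands in for the swapped-in ScatterMoE module so the snippet is self-contained.
from contextlib import nullcontext

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0  # assumed import source

def build_moe_stub(sd: dict, rank: int, device_type: str = "cuda") -> nn.Module:
    # non-zero local ranks never allocate real storage for the weights
    if is_fsdp_enabled() and not is_local_dist_rank_0():
        init_ctx = init_empty_weights()  # parameters are created on the meta device
    else:
        init_ctx = nullcontext()

    with init_ctx:
        module = nn.Linear(8, 8)

    # only local rank 0 (or a non-FSDP run) loads the checkpoint weights
    if not is_fsdp_enabled() or is_local_dist_rank_0():
        module.load_state_dict(sd)
        if is_fsdp_enabled() and rank == 0:
            device = torch.device("cpu")  # keep the full copy off the GPU; FSDP shards it later
        else:
            device = torch.device(f"{device_type}:{rank}")
        module = module.to(device)
    return module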

plugins/accelerated-peft/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@ bitsandbytes >=0.41,<=0.43.3
 threadpoolctl >= 3.5.0
 
 datasets >= 2.20.0
+

plugins/framework/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ classifiers=[
 dependencies = [
     "numpy<2.0",  # numpy needs to be bounded due to incompatibility with current torch<2.3
     "torch>2.2",
-    "peft",
+    "peft<=0.14.0",  # QuantLinear is not available for peft version > 0.14.0
     "accelerate",
     "pandas",
 ]
