
Commit 46ad230

fix: address review comments
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 6c62eda commit 46ad230

File tree

2 files changed: 33 additions & 20 deletions


plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 22 additions & 3 deletions
@@ -42,12 +42,31 @@ def __init__(self, configurations: Dict[str, Dict]):
         super().__init__(configurations)
 
         # ep_degree determines the expert parallel sharding
-        # - default of 1 means experts are not sharded and operate in pure replication.
+        # If disable_distributed is False, expert sharding is handled
+        # by the plugin, else it is deferred to the top-level distribution (e.g. FSDP).
+        #
+        # The defaults of 1 for ep_degree and False for disable_distributed
+        # mean experts are not sharded and operate in pure replication with
+        # Scatter MoE kernels.
+        #
+        # ep_degree==1 and disable_distributed=True means Scatter MoE kernels
+        # are used and distribution is deferred to the top-level distribution
+        # protocol (e.g. FSDP).
+        #
+        # ep_degree>1 and disable_distributed=False means expert parallel and
+        # Scatter MoE kernels are both enabled.
+        #
+        # ep_degree>1 and disable_distributed=True errors out.
         self._ep_degree = self._check_config_and_maybe_check_values(
             key="training.moe.scattermoe.ep_degree",
             default=1,
         )
 
+        self._disable_distributed = self._check_config_and_maybe_check_values(
+            key="training.moe.scattermoe.disable_distributed",
+            default=False,
+        )
+
     @property
     def requires_augmentation(self):
         return True
@@ -77,6 +96,7 @@ def augmentation(
             rank=rank,
             world_size=world_size,
             ep_degree=self._ep_degree,
+            disable_distributed=self._disable_distributed,
             mixed_precision=False,  # Currently this is hardcoded to OFF
         )
         return model, modifiable_args
@@ -91,8 +111,7 @@ def get_callbacks_and_ready_for_train(
             and getattr(accelerator.state, "fsdp_plugin", None) is not None
         ):
 
-            # When EP is not enabled we want to shard the experts using FSDP
-            if self._ep_degree != 0:
+            if not self._disable_distributed:
                 # - use an internal function call to get the no split
                 #   module names, which are typically layers
                 _layers = model._get_no_split_modules("")
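To make the four combinations described in the comment block above concrete, here is a minimal standalone sketch of the behaviour each pair of settings selects; the helper name resolve_expert_sharding is illustrative and not part of the plugin:

def resolve_expert_sharding(ep_degree: int = 1, disable_distributed: bool = False) -> str:
    # Illustrative only: maps the (ep_degree, disable_distributed) pair to the
    # behaviour described in the plugin's comment block.
    if ep_degree > 1 and disable_distributed:
        # mirrors the ValueError added in scattermoe_prepare.py below
        raise ValueError(
            "expert sharding cannot be deferred to the top-level sharding "
            "protocol (e.g. FSDP) when ep_degree > 1"
        )
    if ep_degree > 1:
        return "expert parallel sharding with Scatter MoE kernels"
    if disable_distributed:
        return "Scatter MoE kernels; sharding deferred to FSDP (or similar)"
    return "pure replication with Scatter MoE kernels"

print(resolve_expert_sharding())                          # pure replication
print(resolve_expert_sharding(disable_distributed=True))  # defer to FSDP
print(resolve_expert_sharding(ep_degree=2))               # expert parallel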

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py

Lines changed: 11 additions & 17 deletions
@@ -104,6 +104,7 @@ def prepare_scattermoe(
     rank: int = None,
     world_size: int = None,
     ep_degree: int = 1,
+    disable_distributed: bool = False,
     key_rep: str = KEY_REPLICATE,
     key_ep: str = KEY_EXPERT_PARALLEL,
     device_type: str = "cuda",
@@ -116,13 +117,10 @@ def prepare_scattermoe(
     # pylint: disable=import-outside-toplevel
     from .scattermoe import ScatterMoE
 
-    ep_disabled = False
-    if ep_degree == 0:
-        ep_disabled = True
-        # the flow of code when EP is not enabled is mostly the same as
-        # with ep_degree set to 1. Therefore, we explicitly set
-        # ep_degree to 1 but handle it along with the ep_disabled var
-        ep_degree = 1
+    if disable_distributed and ep_degree > 1:
+        raise ValueError(
+            "expert sharding cannot be deferred to the top-level sharding protocol (e.g. FSDP) when ep_degree > 1"
+        )
 
     assert world_size % ep_degree == 0, (
         f"world size ({world_size}) " f"not divisible by ep_size ({ep_degree})."
@@ -137,11 +135,7 @@ def prepare_scattermoe(
     # current rank of the device
     device = torch.device(f"{device_type}:{rank}")
 
-    if ep_disabled:
-        # Larger models result in OOM, especially when loading
-        # all experts to the same GPU device (when EP is disabled).
-        # For cases like FSDP + EP disabled, it is memory efficient to
-        # load the model to CPU and hand it over to FSDP.
+    if ep_degree == 1 and disable_distributed and is_fsdp_enabled() and rank == 0:
        device = torch.device("cpu")
 
     # get the scattermoe conversion spec
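The rewritten device-selection branch keeps the intent of the removed comment: when the plugin leaves experts unsharded and defers distribution to FSDP, materializing every expert on a single GPU can run out of memory, so rank 0 loads the weights to CPU and hands them over to FSDP. A minimal sketch of that selection, assuming is_fsdp_enabled is imported from transformers.modeling_utils (the actual import in this file is not shown in the diff):

import torch
from transformers.modeling_utils import is_fsdp_enabled  # assumed import location


def select_load_device(
    rank: int, ep_degree: int, disable_distributed: bool, device_type: str = "cuda"
) -> torch.device:
    # default: each rank materializes weights on its own accelerator
    device = torch.device(f"{device_type}:{rank}")
    # when sharding is deferred to FSDP, only rank 0 holds real weights at load
    # time, and packing all experts onto one GPU can OOM, so fall back to CPU
    if ep_degree == 1 and disable_distributed and is_fsdp_enabled() and rank == 0:
        device = torch.device("cpu")
    return device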
@@ -158,7 +152,7 @@
 
     rep_size = world_size // ep_degree
 
-    if ep_degree == 1 and (rep_size == 1 or ep_disabled):
+    if ep_degree == 1:
         # in this case no need for sharding
         device_mesh = None
     elif rep_size == 1:
@@ -281,10 +275,10 @@
     )
 
     if device_mesh is None:
-        if is_fsdp_enabled() and rank > 0:
-            _init_scattermoe_context = init_empty_weights
-        else:
+        if not is_fsdp_enabled() or is_local_dist_rank_0():
             _init_scattermoe_context = nullcontext
+        else:
+            _init_scattermoe_context = init_empty_weights
     else:
         # in this case we need to distribute parameters, so just initialize
         # the scattermoe module swap with empty weights,
@@ -337,7 +331,7 @@
         if device_mesh is None:
             # - if not on meta, just load the state dict
             # - and then put on the device
-            if rank == 0 or not is_fsdp_enabled():
+            if not is_fsdp_enabled() or is_local_dist_rank_0():
                 moe.load_state_dict(sd)
             moe = moe.to(device)
         else:
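The last two hunks replace the plain rank checks with is_local_dist_rank_0(), so on multi-node FSDP runs only each node's local rank 0 materializes real weights while the remaining ranks construct the ScatterMoE swap with empty (meta) weights for FSDP to populate later. A minimal sketch of that gating pattern, assuming init_empty_weights comes from accelerate and the helpers from transformers.modeling_utils (the exact imports are not shown in the diff):

from contextlib import nullcontext

import torch
from accelerate import init_empty_weights  # assumed import location
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0


def build_moe_module(module_cls, state_dict, device, **kwargs):
    # ranks that hold real weights (local rank 0, or any rank when FSDP is off)
    # build the module normally; the rest build it on the meta device
    if not is_fsdp_enabled() or is_local_dist_rank_0():
        init_ctx = nullcontext
    else:
        init_ctx = init_empty_weights

    with init_ctx():
        module = module_cls(**kwargs)

    # only the ranks with real weights load the checkpoint and move to device;
    # meta-initialized ranks are later filled in by FSDP from rank 0
    if not is_fsdp_enabled() or is_local_dist_rank_0():
        module.load_state_dict(state_dict)
        module = module.to(device)
    return module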
