Commit feaeaa5

partially address mixed precision
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
1 parent 12d8619 commit feaeaa5

File tree

3 files changed: +33 −3 lines changed
plugins/accelerated-moe/README.md

Lines changed: 4 additions & 1 deletion

```diff
@@ -44,13 +44,16 @@ bash scripts/run_benchmarks.sh \
 
 ## Expert-Parallel MoE with Megablocks
 
+Currently supports *mixed precision*; when turned on, the router and the sharded experts will be upcast.
+- However, this is hard-coded to off at the moment.
+- The FSDP mixed precision works independently of the MoE one.
+
 Not all of the features of `megablocks` are being incorporated; listing down some of the restrictions of the current integration:
 - currently not passing the data parallel `dp_mesh` to the `FSDP` constructor, so `FSDP` will always shard over the default process group (over world_size).
 - currently supports loading only *sharded* `safetensor` non-GGUF MoE checkpoints. This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents them from being saved into a single checkpoint file.
 - only supports the *dropless sparse* MLPs in the megablocks package; the other variations like non-dropless and grouped computes are not currently integrated.
 - `shard_moe` may not scale well with larger models, as the current implementation `torch.concat`s all the expert weights together before passing them to `torch.distributed` to be sharded. This is redundantly done on all devices, so it is inefficient.
 - currently only supports `StateDictType.SHARDED_STATE_DICT` because the implementation uses `DTensors`, which have limited support for full state dicts. Sharded state dicts are also the most efficient choice.
-- currently may not support *mixed precision* properly; need to ascertain more clearly how the sharded `DTensors` are upcasted in the optimizer (if at all).
 
 ### Megablocks Dependencies
 
```
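For context on the second bullet above, here is a minimal sketch of the two independent mixed-precision knobs: the `MixedPrecision` policy belongs to PyTorch FSDP, while the `mixed_precision` flag on `shard_moe` is the plugin-side upcast (shown only as an illustrative, commented-out call, since it is currently hard-coded to off).

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# FSDP's own mixed-precision policy: controls the dtypes used for FSDP-managed
# parameters, gradient reduction, and buffers. It is configured on the FSDP
# wrapper and is unaffected by the MoE flag below.
fsdp_mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# The MoE-side flag (illustrative only): when enabled, shard_moe upcasts the
# router and the sharded expert DTensors to float32 before registering them.
# It is currently hard-coded to False inside the plugin.
# shard_moe(model, ..., mixed_precision=False)
```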

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -151,7 +151,13 @@ def model_loader(self, model_name: str, **kwargs):
             shared_mesh_dim=self._shard_along_dp,
             router_name=self._gate_module_name,
             expert_name=self._experts_module_name,
+            mixed_precision=False,  # Currently this is hardcoded to OFF
         )
+
+        # NOTE: there is currently no good way to get the mixed precision
+        # flag from train_args. It will be better to handle this if/when
+        # we move the sharding to augmentation.
+
         # NOTE: Currently, it is a bit troublesome to pass the device_mesh to
         # the FSDP constructor, so we do not do that.
         # - therefore FSDP will always shard on world_size over the default process
```
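One possible shape for the deferred handling mentioned in the NOTE above, sketched under the assumption that the sharding eventually runs at a point where the HuggingFace `TrainingArguments` are in scope; the helper name `_infer_mixed_precision` is hypothetical, not part of the plugin.

```python
def _infer_mixed_precision(train_args) -> bool:
    # transformers.TrainingArguments exposes fp16/bf16 booleans; if either is
    # set, the trainer runs in mixed precision, which is when the MoE upcast
    # of the router and sharded experts would be relevant.
    return bool(getattr(train_args, "bf16", False) or getattr(train_args, "fp16", False))


# e.g., once sharding happens where train_args is available (hypothetical call):
# shard_moe(model, ..., mixed_precision=_infer_mixed_precision(train_args))
```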

plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py

Lines changed: 23 additions & 2 deletions

```diff
@@ -20,6 +20,7 @@
 import json
 import os
 import re
+import warnings
 
 # Third Party
 from accelerate import init_empty_weights
@@ -175,6 +176,7 @@ def load_sharded_experts_onto_device(
     device_mesh: DeviceMesh,
     placements: Placement,
     expert_name: str = "experts",  # e.g., named "experts" within block_sparse_moe
+    mixed_precision: bool = False,
 ):
     # typically they all should be same file, but to play safe, load the checkpoint file onto
     # cpu first since we may not need all weights in that file.
@@ -191,6 +193,7 @@ def load_sharded_experts_onto_device(
 
     # go by one weight at a time.
     # - weight_name: points to megablocks.dmoe
+    upcasted = set()
     for weight_name, vs in checkpoint_metadata.items():
         data = []
         for k, fi in vs:
@@ -204,11 +207,18 @@ def load_sharded_experts_onto_device(
         name = weight_name.split(".")
         path, name = ".".join(name[:-1]), name[-1]
         mod = dmoe.get_submodule(path)
-        mod_dtype = getattr(mod, name).dtype
+
+        # if mixed_precision and KEY_DMOE_ROUTER not in weight_name:
+        if mixed_precision:
+            mod_dtype = torch.float32
+            upcasted.add(weight_name)
+        else:
+            mod_dtype = getattr(mod, name).dtype
 
         # the megablocks dmoe expects the expert features to be on DIM_EXPERT.
         # - concat on dim 0 and distribute
         # - cast to the correct dtype for the module
+        # - if mixed precision is enabled, the sharded params are cast to float32
         param = torch.concat(data, dim=DIM_EXPERT).to(mod_dtype)
 
         _placements = placements
@@ -223,6 +233,9 @@ def load_sharded_experts_onto_device(
         # register the sharded parameter onto the megablocks.dmoe
         mod.register_parameter(name, param)
 
+    if upcasted:
+        upcasted = ", ".join(sorted(upcasted))
+        warnings.warn(f"Mixed precision turned on, upcasted MoE parameters: {upcasted}")
 
 def shard_moe(
     model: torch.nn.Module,
@@ -238,6 +251,7 @@ def shard_moe(
     expert_name: str = "experts",
     shared_mesh_dim: bool = True,
     ep_size: int = 1,
+    mixed_precision: bool = False,
 ):
     """shard_moe takes a mixture-of-experts huggingface model and shards the experts
     on the current device. All layers that have a MoE module will be sharded.
@@ -272,6 +286,7 @@ def shard_moe(
         expert_name (str): module name of the experts in moe_cls (e.g., "experts").
         shared_mesh_dim (bool): for the sharding mode, see explanation above.
        ep_size (int): for shared_mesh_dim=False only, see explanation above.
+        mixed_precision (bool): activate mixed precision and upcast the sharded params.
 
     """
     # guarded import
@@ -389,7 +404,13 @@ def shard_moe(
         mp_dmoe = dmoe.dMoE(_args)  # drop in replacement for now
 
         load_sharded_experts_onto_device(
-            mp_dmoe, loc, checkpoint_metadata, device_mesh, placements, expert_name
+            mp_dmoe,
+            loc,
+            checkpoint_metadata,
+            device_mesh,
+            placements,
+            expert_name,
+            mixed_precision,
         )
         parent = model.get_submodule(prefix)
         setattr(parent, module_name, mp_dmoe)
```
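To make the upcast path above concrete, the following is a minimal, self-contained sketch of the same pattern (concat the experts, pick the dtype, shard into a `DTensor`); the function name and arguments are illustrative rather than the plugin's actual helpers, and an initialized process group plus `DeviceMesh` are assumed.

```python
from typing import List

import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

DIM_EXPERT = 0  # experts are concatenated along dim 0 before sharding


def upcast_and_shard_experts(
    expert_weights: List[torch.Tensor],
    device_mesh: DeviceMesh,
    module_dtype: torch.dtype,
    mixed_precision: bool = False,
) -> torch.nn.Parameter:
    # with mixed precision on, keep the sharded expert weights in float32;
    # otherwise follow the dtype of the (possibly half-precision) module.
    target_dtype = torch.float32 if mixed_precision else module_dtype

    # concat all experts on DIM_EXPERT, cast, then shard the result across the mesh
    full = torch.concat(expert_weights, dim=DIM_EXPERT).to(target_dtype)
    sharded = distribute_tensor(full, device_mesh, placements=[Shard(DIM_EXPERT)])
    return torch.nn.Parameter(sharded)
```

The resulting parameter would then be registered back onto the module, as the diff above does with `register_parameter`.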
