
Commit 1d6a42b

more cleanup
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
1 parent 97a0bb4 commit 1d6a42b

15 files changed: +156 -327 lines

plugins/accelerated-moe/README.md

Lines changed: 19 additions & 2 deletions
@@ -16,7 +16,25 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks
 Our `ScatterMoe` implementation is a module-swap; to add new models we need to update the specifications in [scattermoe_constants.py](./src/fms_acceleration_moe/utils/scattermoe_constants.py).
 - See the code documentation within to understand how to add new models.
 
-### Code Extracted from Megablocks
+### Conversion of ScatterMoE
+
+`ScatterMoE` checkpoints are saved using `torch.distributed.checkpoint` (DCP), which by default uses `StateDictType.SHARDED_STATE_DICT`:
+- `DTensors` have limited support for full state dicts,
+- sharded state dicts are extremely efficient and require little communication overhead when saving.
+
+We provide a script to recover the original checkpoint:
+- currently the script can only be used if DCP saved a single `pytorch_model_fsdp_0` folder,
+- say the checkpoint is stored in `hf/checkpoint-10`, then call:
+
+```
+python -m fms_acceleration_moe.utils.checkpoint_utils \
+    hf/checkpoint-10/pytorch_model_fsdp_0 \
+    output_dir mistralai/Mixtral-8x7B-Instruct-v0.1
+```
+
+
+
+## Code Extracted from Megablocks
 
 Notes on code extraction:
 - we have only extracted two `autograd` functions [GatherOp](https://github.com/databricks/megablocks/blob/main/megablocks/ops/gather.py) and [ScatterOp](https://github.com/databricks/megablocks/blob/main/megablocks/ops/scatter.py),
@@ -71,6 +89,5 @@ These are currently some known issues not yet resolved:
 - The design currently does a swap for the mixture-of-expert module with [ScatterMoE](./src/fms_acceleration_moe/utils/scattermoe.py). This affects the `state_dict` of the model, so any saved checkpoint may need to be converted back to original.
 - should eventually remove the dependency on an external `kernel-hyperdrive` repository.
 - now supports only loading *sharded* `safetensor` non-GGUF MoE checkpoints. This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents them being saved into a single checkpoint file.
-- currently only supports `StateDictType.SHARDED_STATE_DICT` because the implementation uses `DTensors` which have limited support for full state dicts. However for efficiency considerations, sharded state dicts are the most efficient.
 
 
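Not part of this commit: a minimal sketch of how the file written by the command above could be loaded back into the original Mixtral architecture. The output path (`output_dir`) is assumed to be the single file the script writes with `torch.save`, and `strict=False` is used defensively; both are assumptions for illustration.

```python
# sketch only: load the converted ScatterMoE checkpoint into the original model
import torch
from transformers import AutoModelForCausalLM

# the original architecture that the checkpoint is being restored for
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

# "output_dir" is the second positional argument passed to checkpoint_utils above;
# the script saves the remapped state dict there with torch.save
state_dict = torch.load("output_dir", map_location="cpu")

missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
```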

plugins/accelerated-moe/configs/scattermoe.yaml

Lines changed: 8 additions & 34 deletions
@@ -6,37 +6,11 @@ training:
   # expert-parallel for MoE
   scattermoe:
 
-    # TODO: should we even get rid of this?
-    # The name of the mixture-of-experts class
-    # moe_component_class: MixtralSparseMoeBlock
-    # moe_component_class: GraniteMoeMoE
-
-    # The module name of the router in moe_component_class above
-    # moe_gate_module_name: gate
-
-    # The module name of the experts in moe_component_class above
-    # moe_experts_module_name: experts
-
-    # the mlp version
-    # - for those with only up and down projs, use "v1"
-    # - for those with only up, down and gate projs, use "v2"
-    # moe_mlp_impl: v2
-
-    # if True, then we shard experts across data parallel dimension
-    # - only feasible if world_size divides the number of experts
-    # shard_along_dp: true
-
-    # to be specified only if shard_along_dp == False. This will influence
-    # the level of sharding, which indicates how many experts per device
-    # - the number of experts per device will be num_experts / ep_size
-    # - we disable the ability to set ep_size=1 since this means no sharding
-    # - NOTE: ep_size=1 does not mean shard_along_dp=True, which would otherwise
-    #   be contradictory since ep_size suggests no expert parallel.
-    ep_degree: 1
-
-    # the MoE dropless implementation. Currently we only support "dropless_sparse", but
-    # in the future we may support others
-    # moe_implementation: dropless_sparse
-
-    # for load_balancing_loss
-    # load_balancing_loss: false
+    # The level of expert parallel sharding.
+    # - 1 means no sharding
+    # - if > 1, please ensure that this divides the world_size. This is because
+    #   the devices will be replicated for every ep_degree devices, and
+    #   the experts will be sharded within each group.
+    # - if > 1, also ensure that it divides the number of experts, as each device
+    #   will then have num_of_experts / ep_degree experts.
+    ep_degree: 1
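To make the new `ep_degree` comments concrete, a small sketch with hypothetical values (not from the repository) of the divisibility rules they describe:

```python
# hypothetical values illustrating the ep_degree constraints documented above
world_size = 8      # total number of devices in the job
num_experts = 8     # e.g. a Mixtral-style MoE layer with 8 experts
ep_degree = 4       # experts are sharded within groups of ep_degree devices

assert world_size % ep_degree == 0, "ep_degree must divide world_size"
assert num_experts % ep_degree == 0, "ep_degree must divide the number of experts"

experts_per_device = num_experts // ep_degree   # -> 2 experts per device
replica_groups = world_size // ep_degree        # -> 2 groups of 4 devices
print(experts_per_device, replica_groups)
```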

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 9 additions & 4 deletions
@@ -31,10 +31,14 @@
 # pylint: disable=too-many-instance-attributes
 class ScatterMoEAccelerationPlugin(AccelerationPlugin):
 
-    # NOTE: its not packaged properly so, "importlib.util.find_spec('khd')"
-    # returns but "importlib.metadata.version('kernel-hyperdrive') is needed"
-    # require_packages = {"khd"}
-    # NOTE: will address this later if we remove the dependency on kernel-hyperdrive
+    # NOTE: we cannot do
+    # - require_packages = {"khd"}
+    # this is because the khd fork is not properly packaged as a PyPI project, and so
+    # - "importlib.util.find_spec('khd')" returns, but
+    # - "importlib.metadata.version('kernel-hyperdrive')" fails
+    # if we decide to extract the kernels, we will not need this anymore; see
+    # https://github.com/foundation-model-stack/fms-acceleration/issues/105
+
     restricted_model_archs = ["GraniteMoeForCausalLM", "MixtralForCausalLM"]
 
     def __init__(self, configurations: Dict[str, Dict]):
@@ -75,6 +79,7 @@ def model_loader(self, model_name: str, **kwargs):
         # NOTE: there is currently no good way to get the mixed precision
         # flag from train_args. It will be better to handle this if
         # when we move the sharding to augmentation.
+        # https://github.com/foundation-model-stack/fms-acceleration/issues/103
 
         return model
 
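To illustrate the packaging note in the new comment (the scenario is an assumption: a source checkout of `khd` is importable, but no `kernel-hyperdrive` distribution metadata is installed):

```python
# sketch: why a metadata-based require_packages check fails for the khd fork
import importlib.metadata
import importlib.util

# find_spec only needs the module to be importable (e.g. a source checkout on
# PYTHONPATH), so it can succeed even without installed distribution metadata
print("importable:", importlib.util.find_spec("khd") is not None)

# metadata.version needs an installed distribution; without one it raises
try:
    print(importlib.metadata.version("kernel-hyperdrive"))
except importlib.metadata.PackageNotFoundError:
    print("no distribution metadata for kernel-hyperdrive")
```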

plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -18,10 +18,11 @@
 
 # this is a special patch function to disable foreach for
 # dtensors, which has been introduced since torch 2.4.
-# The reason is because this will cause problems in the optimizer
-# lerp.
-
+# The reason is that this causes problems in the optimizer, e.g.:
+# RuntimeError: aten._foreach_mul_.Scalar: got mixed torch.Tensor and DTensor,
+# need to convert all torch.Tensor to DTensor before calling distributed operators!
 
+# - this function patches torch
 def patch_torch_optim_foreach_to_not_apply_to_dtensors():
     # guarded.
     # this is an array of supported types, we will remove
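For illustration only (this is not the actual patch): the fused `_foreach_*` optimizer kernels expect a homogeneous list of plain tensors, which is why mixing `torch.Tensor` and `DTensor` parameters triggers the RuntimeError quoted above and why the foreach path must be disabled for DTensors.

```python
import torch

def safe_for_foreach(tensors):
    # hand-written check for illustration: a parameter list that mixes plain
    # torch.Tensor with tensor subclasses (such as DTensor) cannot go down the
    # fused foreach path and must fall back to per-tensor updates
    return all(type(t) is torch.Tensor for t in tensors)

params = [torch.zeros(2), torch.ones(2)]
print(safe_for_foreach(params))  # True: plain tensors only
```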

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 55 additions & 21 deletions
@@ -46,13 +46,19 @@
 # - variable to capture the model variable
 #   in the save/load model calls
 MODEL_INDEX = None
+KEY_MODEL = 'model'
+KEY_OPTIMIZER = 'optimizer'
 
-# Below are rewrite of functions to be able to handle dtensors
-
+# Below are rewrites of the HF FSDP model saving functions, to handle the fact
+# that the parameters are now a mixture of regular tensors and DTensors.
+# - these functions are found in accelerate.utils.fsdp_utils.py
+# - save_fsdp_model, save_fsdp_optimizer, load_fsdp_model, load_fsdp_optimizer
+# NOTE: we will observe warnings such as
+#   /torch/distributed/checkpoint/state_dict.py:520:
+#   FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
 
 # rewrite of func from accelerate.utils.fsdp_utils.py
-# - empty function, as main logic is in the optimizer call
-#   save_fsdp_optimizer (see below).
+# - empty function, the main logic will be in save_fsdp_optimizer (see below).
 def save_fsdp_model(
     fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False
 ):
@@ -62,7 +68,7 @@ def save_fsdp_model(
 
 
 # rewrite of func from accelerate.utils.fsdp_utils.py
-# - saves both model and optimizer
+# - saves both model and optimizer at the same time
 def save_fsdp_optimizer(
     fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0
 ):
@@ -80,7 +86,7 @@ def save_fsdp_optimizer(
     os.makedirs(ckpt_model, exist_ok=True)
     logger.info(f"Saving model to {ckpt_model}")
     dcp.save(
-        state_dict={"model": model_state_dict},
+        state_dict={KEY_MODEL: model_state_dict},
         storage_writer=dcp.FileSystemWriter(ckpt_model),
         planner=DefaultSavePlanner(),
     )
@@ -91,16 +97,15 @@
     os.makedirs(ckpt_opt, exist_ok=True)
     logger.info(f"Saving Optimizer state to {ckpt_opt}")
     dcp.save(
-        state_dict={"optimizer": optimizer_state_dict},
+        state_dict={KEY_OPTIMIZER: optimizer_state_dict},
         storage_writer=dcp.FileSystemWriter(ckpt_opt),
         planner=DefaultSavePlanner(),
     )
     logger.info(f"Optimizer state saved in {ckpt_opt}")
 
 
 # rewrite of func from accelerate.utils.fsdp_utils.py
-# - empty function, as main logic is in the optimizer call
-#   load_fsdp_optimizer (see below).
+# - empty function, main logic in load_fsdp_optimizer (see below).
 def load_fsdp_model(
     fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False
 ):
@@ -133,15 +138,15 @@ def load_fsdp_optimizer(
     # - load the model state dict
     ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
     dcp.load(
-        state_dict={"model": model_state_dict},
+        state_dict={KEY_MODEL: model_state_dict},
         storage_reader=dcp.FileSystemReader(ckpt_model),
         planner=DefaultLoadPlanner(),
     )
 
     # - load the optimizer state dict
     ckpt_opt = os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
     dcp.load(
-        state_dict={"optimizer": optimizer_state_dict},
+        state_dict={KEY_OPTIMIZER: optimizer_state_dict},
         storage_reader=dcp.FileSystemReader(ckpt_opt),
         planner=DefaultLoadPlanner(),
     )
@@ -154,10 +159,15 @@
         optim_state_dict=optimizer_state_dict,
     )
 
-    # HACK for now
-    # - if seems that if params is empty, then the loading has someo
-    #   problems
-    # - so for now, we just dump some random defaults
+    # FIXME:
+    # - We see errors that occur in optimizer.step()
+    #   - torch/optim/optimizer.py", line 89, in _use_grad
+    #   - torch/optim/adamw.py", line 214, in step beta1, beta2 = cast(Tuple[float, float], group["betas"])
+    #   - KeyError: 'betas'
+    # - Fortunately, this seems to be limited to the empty-group case, where
+    #   the params are simply not initialized. Since we assume these groups are
+    #   never used, we initialize the empty groups with arbitrary defaults so
+    #   the error is not thrown.
     for group in optimizer.param_groups:
         if len(group["params"]) == 0:
             group["betas"] = (0.9, 0.999)
@@ -182,8 +192,8 @@ def patch_huggingface_save_and_load_for_dtensors():
     patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)
 
 
-# trick to get the resolved cache file to acccess the safetensor
-# NOTE: this does not work if _dict_from_json_file, like GGUF files
+# this function implements a trick to get the resolved cache file to access the safetensor
+# - NOTE: does not work if _dict_from_json_file is not called, such as in the case of GGUF files.
 def get_resolved_checkpoint_location(model_name_or_path: str):
 
     result = None
@@ -201,14 +211,17 @@ def _dict_from_json_file(resolved_config_file):
     return os.path.dirname(result)
 
 
-def restore_scattermoe_checkpoint_to_orig(
+# function to get the ScatterMoE state dict from its DCP checkpoint
+# - if the original pretrained_model_name_or_path is specified, we will use its checkpoint
+#   as hints to map the ScatterMoE checkpoint back to that of the original model. This is
+#   useful so that the restored checkpoint can be loaded by the original architecture.
+def get_scattermoe_state_dict(
     dcp_checkpoint_dir: str,
     pretrained_model_name_or_path: str = None,
-    dcp_outer_key: str = "model",
 ):
     """
     Parameters:
-        dcp_checkpoint_dir (str): the dcp to be converted.
+        dcp_checkpoint_dir (str): the DCP to be converted.
         pretrained_model_name_or_path (str): Optional, if provided we will
             use the hints to remap the
     """
@@ -230,7 +243,7 @@ def restore_scattermoe_checkpoint_to_orig(
         planner=_EmptyStateDictLoadPlanner(),
         no_dist=True,
     )
-    sd = sd[dcp_outer_key]
+    sd = sd[KEY_MODEL]
 
     # if not provided
     if pretrained_model_name_or_path is None:
@@ -401,6 +414,16 @@ def _infer_prefixes_and_module_names(
         )
     )
 
+    parser.add_argument(
+        "dcp_checkpoint_dir",
+        help="Path to the distributed checkpoint.",
+    )
+
+    parser.add_argument(
+        "output_dir",
+        help="Path to the location to write the converted checkpoint."
+    )
+
     parser.add_argument(
         "pretrained_model_name_or_path",
         help=(
@@ -409,3 +432,14 @@
             "checkpoint is obtained)."
         ),
     )
+
+    args = parser.parse_args()
+
+    # get the converted statedict
+    sd = get_scattermoe_state_dict(
+        args.dcp_checkpoint_dir,
+        args.pretrained_model_name_or_path
+    )
+
+    # save it
+    torch.save(sd, args.output_dir)
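Presumably the same conversion can also be driven programmatically; a minimal sketch mirroring the `__main__` block above (the paths reuse the README example, and the output filename is an assumption):

```python
# sketch: programmatic equivalent of the CLI entrypoint added above
import torch
from fms_acceleration_moe.utils.checkpoint_utils import get_scattermoe_state_dict

# dcp_checkpoint_dir: the DCP folder written during training (README example path)
# pretrained_model_name_or_path: provides hints to remap keys to the original arch
sd = get_scattermoe_state_dict(
    "hf/checkpoint-10/pytorch_model_fsdp_0",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
)

# the CLI simply torch.save's the remapped state dict to the given output path
torch.save(sd, "mixtral_restored.bin")  # hypothetical output filename
```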

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@
     ) from e
 
 # Local
-from .scattermoe_constants import SCATTERMOE_SPEC_HAS_GATE_WEIGHT
+from .scattermoe_constants import SCATTERMOE_SPEC_HAS_GATE
 from .scattermoe_utils import all_to_all_gather_inputs, scatter_with_routing_weights
 
 
@@ -306,7 +306,7 @@ def __init__(
             device=device,
             lora_config=lora_config,
         )
-        if mlp_arch == SCATTERMOE_SPEC_HAS_GATE_WEIGHT:
+        if mlp_arch == SCATTERMOE_SPEC_HAS_GATE:
             self.w3 = ScatteredExperts(
                 in_features=self.hidden_size,
                 out_features=self.intermediate_size,

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py

Lines changed: 3 additions & 5 deletions
@@ -29,7 +29,7 @@
 # Currently our ScatterMoE drop supports an up/down proj, and
 # an optional gate_proj.
 # - When new architectures are supported this list will update
-SCATTERMOE_SPEC_HAS_GATE_WEIGHT = "has_gate_proj"
+SCATTERMOE_SPEC_HAS_GATE = "Gated"
 
 # - moe_cls
 # - router_name
@@ -66,21 +66,19 @@
         "MixtralSparseMoeBlock",
         "gate",
         "experts",
-        SCATTERMOE_SPEC_HAS_GATE_WEIGHT,
+        SCATTERMOE_SPEC_HAS_GATE,
         True,
     ),
     "GraniteMoeForCausalLM": (
         "GraniteMoeMoE",
         "router",
         "input_linear|output_linear|input_linear",
-        SCATTERMOE_SPEC_HAS_GATE_WEIGHT,
+        SCATTERMOE_SPEC_HAS_GATE,
         False,
     ),
 }
 
 # helper function to get the spec based on architectures
-
-
 def get_scattermoe_conv_spec_from_archs(architectures: List[str]):
     # infer the spec
     for archs, spec in SCATTERMOE_CONVERSION_SPEC.items():
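A brief sketch, based only on what is visible in this hunk, of how the conversion spec might be looked up by architecture; the exact contents of the returned tuple beyond the fields shown above are an assumption.

```python
from fms_acceleration_moe.utils.scattermoe_constants import (
    SCATTERMOE_SPEC_HAS_GATE,
    get_scattermoe_conv_spec_from_archs,
)

# look up the spec for a Mixtral checkpoint; per the table above this should
# carry the MoE block class name, the router/experts module names, and the
# "Gated" marker that selects the gated (w3) ScatteredExperts layout
spec = get_scattermoe_conv_spec_from_archs(["MixtralForCausalLM"])
print(spec)
print(SCATTERMOE_SPEC_HAS_GATE in spec)  # expected True for the Mixtral entry
```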

sample-configurations/moe-scattermoe-granite-ep1-padding-free-sample-configuration.yaml

Lines changed: 7 additions & 32 deletions
@@ -18,36 +18,11 @@ plugins:
   # expert-parallel for MoE
   scattermoe:
 
-    # TODO: should we even get rid of this?
-    # The name of the mixture-of-experts class
-    # moe_component_class: MixtralSparseMoeBlock
-    # moe_component_class: GraniteMoeMoE
-
-    # The module name of the router in moe_component_class above
-    # moe_gate_module_name: gate
-
-    # The module name of the experts in moe_component_class above
-    # moe_experts_module_name: experts
-
-    # the mlp version
-    # - for those with only up and down projs, use "v1"
-    # - for those with only up, down and gate projs, use "v2"
-    # moe_mlp_impl: v2
-    # if True, then we shard experts across data parallel dimension
-    # - only feasible if world_size divides the number of experts
-    # shard_along_dp: true
-
-    # to be specified only if shard_along_dp == False. This will influence
-    # the level of sharding, which indicates how many experts per device
-    # - the number of experts per device will be num_experts / ep_size
-    # - we disable the ability to set ep_size=1 since this means no sharding
-    # - NOTE: ep_size=1 does not mean shard_along_dp=True, which would otherwise
-    #   be contradictory since ep_size suggests no expert parallel.
+    # The level of expert parallel sharding.
+    # - 1 means no sharding
+    # - if > 1, please ensure that this divides the world_size. This is because
+    #   the devices will be replicated for every ep_degree devices, and
+    #   the experts will be sharded within each group.
+    # - if > 1, also ensure that it divides the number of experts, as each device
+    #   will then have num_of_experts / ep_degree experts.
     ep_degree: 1
-
-    # the MoE dropless implementation. Currently we only support "dropless_sparse", but
-    # in the future we may support others
-    # moe_implementation: dropless_sparse
-
-    # for load_balancing_loss
-    # load_balancing_loss: false
