
Commit b45fdda: more cleanup

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

1 parent: 6d1e646

File tree: 4 files changed (+76 / -43 lines)

plugins/accelerated-moe/README.md

Lines changed: 12 additions & 1 deletion
@@ -20,7 +20,7 @@ Run the below in the top-level directory of this repo:
   tox -e run-benches \
     -x testenv:run-benches.deps+="-r plugins/accelerated-moe/requirements-khd.txt" \
     -- \
-    "1 2 4 8" 128 benchmark_outputs scenarios-granite.yaml accelerated-moe-scatter
+    "1 2 4 8" 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-scatter
 ```
 or run the larger `Mixtral-8x7B` bench:
 ```
@@ -43,10 +43,21 @@ bash scripts/run_benchmarks.sh \
 ....
 ```
 
+
 ### Triton Kernel Dependencies
 
 Currently we do not copy the `scattermoe` kernels into this repository, so this is an additional manual install:
 
 ```
 # this will install the kernel-hyperdrive fork with the scattermoe triton kernels
 pip install -r requirements-khd.txt
+
+### Known Issues
+
+These are some known issues that are not yet resolved:
+- The design currently swaps the mixture-of-experts module with [ScatterMoE](./src/fms_acceleration_moe/utils/scattermoe.py). This affects the `state_dict` of the model, so any saved checkpoint may need to be converted back to the original format.
+- The dependency on the external `kernel-hyperdrive` repository should eventually be removed.
+- Currently only *sharded* `safetensor` (non-GGUF) MoE checkpoints can be loaded. This is a reasonable assumption, since MoE checkpoints are typically above the size limit that prevents them from being saved into a single checkpoint file.
+- Only `StateDictType.SHARDED_STATE_DICT` is supported, because the implementation uses `DTensors`, which have limited support for full state dicts; sharded state dicts are also the most efficient option.
+
+
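Note on the `SHARDED_STATE_DICT` known issue above: the training setup has to ask FSDP for sharded (per-rank) checkpoints. The snippet below is a minimal illustrative sketch using the plain PyTorch FSDP API; it is not code from this repo, and the wrapping details are assumptions.

```python
# Illustrative sketch only (not repo code): configure FSDP so that
# model.state_dict() returns sharded per-rank shards, matching the
# plugin's SHARDED_STATE_DICT-only support.
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedStateDictConfig,
    StateDictType,
)

def use_sharded_state_dict(model: FSDP) -> dict:
    FSDP.set_state_dict_type(
        model,
        StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
    )
    # each rank now returns only its own shards instead of a full state dict
    return model.state_dict()
```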

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 14 additions & 41 deletions
@@ -19,16 +19,19 @@
 # Third Party
 from fms_acceleration import AccelerationPlugin
 from transformers import AutoConfig, AutoModelForCausalLM
-import torch
+import torch
 
+from .utils import (
+    prepare_scattemoe, patch_huggingface_save_and_load_for_dtensors, patch_torch_optim_foreach_to_not_apply_to_dtensors
+)
 
 # pylint: disable=too-many-instance-attributes
 class ScatterMoEAccelerationPlugin(AccelerationPlugin):
 
     # NOTE: its not packaged properly so, "importlib.util.find_spec('khd')"
     # returns but "importlib.metadata.version('kernel-hyperdrive') is needed"
     # require_packages = {"khd"}
-
+    # NOTE: will address this later if we remove the dependency on kernel-hyperdrive
     restricted_model_archs = [
         'GraniteMoeForCausalLM', 'MixtralForCausalLM'
     ]
@@ -49,11 +52,6 @@ def requires_custom_loading(self):
 
     def model_loader(self, model_name: str, **kwargs):
 
-        # guarded
-        # Local
-        # pylint: disable=import-outside-toplevel
-        from .utils import prepare_scattemoe
-
         # load the model
         model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
 
@@ -62,19 +60,14 @@ def model_loader(self, model_name: str, **kwargs):
         world_size = torch.distributed.get_world_size()
         rank = torch.distributed.get_rank()
 
-        # shard the MOE, and store products required for
-        # FSDP configuration
-        # pylint: disable=unused-variable
+        # shard the MOE, and store the component names, eventually needed
+        # to configure the FSDP
         self._moe_component_module_names = prepare_scattemoe(
             model,
-            # self._moe_component_cls,
             checkpoint_name_or_path=model_name,
             rank=rank,
             world_size=world_size,
             ep_degree=self._ep_degree,
-            # shared_mesh_dim=self._shard_along_dp,
-            # router_name=self._gate_module_name,
-            # expert_name=self._experts_module_name,
             mixed_precision=False, # Currently this is hardcoded to OFF
         )
 
@@ -93,17 +86,6 @@ def get_callbacks_and_ready_for_train(
             accelerator is not None
             and getattr(accelerator.state, "fsdp_plugin", None) is not None
         ):
-            # TODO: refactor
-            # for newer torch that enables foreach for Dtensors we need to remove it
-            from torch.optim.optimizer import _foreach_supported_types
-
-            i = 0
-            while i < len(_foreach_supported_types):
-                x = _foreach_supported_types[i]
-                if x.__name__ == 'DTensor':
-                    _foreach_supported_types.pop(i)
-                else:
-                    i += 1
 
             # - use an internal function call to get the no split
             #   module names, which are typically layers
@@ -115,22 +97,13 @@
                 if layer.__class__.__name__ in _layers
             ]
 
-            # Third Party
-            # TODO: REFACTOR
-            from fms_acceleration.model_patcher import patch_target_module
-
-            # Local
-            from .utils.checkpoint_utils import (
-                load_fsdp_model,
-                load_fsdp_optimizer,
-                save_fsdp_model,
-                save_fsdp_optimizer,
-            )
-
-            patch_target_module("transformers.trainer.save_fsdp_model", save_fsdp_model)
-            patch_target_module("transformers.trainer.save_fsdp_optimizer", save_fsdp_optimizer)
-            patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model)
-            patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)
+            # call this to patch the HF save and load functions to be able
+            # to save DTensors properly
+            patch_huggingface_save_and_load_for_dtensors()
+
+            # call this to patch torch optim to not use
+            # foreach for dtensors
+            patch_torch_optim_foreach_to_not_apply_to_dtensors()
 
         return callbacks
 
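For context, the `prepare_scattemoe` call in `model_loader` above can be exercised on its own roughly as follows. This is a minimal sketch assuming a `torch.distributed` process group is already initialized (for example via `torchrun`); the checkpoint name and `ep_degree` value are illustrative, and only the keyword arguments visible in the diff are used.

```python
# Minimal sketch of the prepare_scattemoe call shown in model_loader above.
# Assumes torch.distributed is initialized (e.g. launched with torchrun);
# the checkpoint and ep_degree below are illustrative values.
import torch
from transformers import AutoModelForCausalLM
from fms_acceleration_moe.utils import prepare_scattemoe

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # example MoE checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)

moe_component_module_names = prepare_scattemoe(
    model,
    checkpoint_name_or_path=model_name,   # sharded safetensor checkpoint
    rank=torch.distributed.get_rank(),
    world_size=torch.distributed.get_world_size(),
    ep_degree=2,                          # expert-parallel degree (illustrative)
    mixed_precision=False,                # hardcoded to OFF in the plugin
)
```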

plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py

Lines changed: 38 additions & 1 deletion
@@ -1 +1,38 @@
-from .scattermoe_prepare import prepare_scattemoe
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .scattermoe_prepare import prepare_scattemoe
+from .checkpoint_utils import patch_huggingface_save_and_load_for_dtensors
+
+# this is a special patch function to disable foreach for
+# dtensors, which has been introduced since torch 2.4,
+# because it causes problems in the optimizer
+# lerp.
+
+def patch_torch_optim_foreach_to_not_apply_to_dtensors():
+    # guarded import.
+    # _foreach_supported_types is the list of supported types; we remove
+    # DTensor from it so the optimizer falls back to
+    # per-parameter updates
+    from torch.optim.optimizer import _foreach_supported_types
+
+    i = 0  # list index
+    while i < len(_foreach_supported_types):
+        x = _foreach_supported_types[i]
+        if x.__name__ == 'DTensor':
+            # pop from list
+            _foreach_supported_types.pop(i)
+        else:
+            i += 1
+
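As an aside, the same effect can be sketched more concisely with an in-place filter. This is only an illustrative equivalent (assuming torch >= 2.4, where `DTensor` appears in `_foreach_supported_types`), not code from the repo:

```python
# Illustrative equivalent of the patch above: filter DTensor out of the
# foreach-supported types in place, so any other module holding a reference
# to the same list object also sees the change.
from torch.optim.optimizer import _foreach_supported_types

_foreach_supported_types[:] = [
    t for t in _foreach_supported_types if t.__name__ != "DTensor"
]
```

The in-place slice assignment (rather than rebinding the name) matters for the same reason the original loop uses `pop`: the list object itself is shared with `torch.optim`.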

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 12 additions & 0 deletions
@@ -150,3 +150,15 @@ def load_fsdp_optimizer(
         group["initial_lr"] = 0.0
         group["eps"] = 1e-8
         group["weight_decay"] = 0.0
+
+# function to replace various trainer functions in HF with the ones
+# above
+def patch_huggingface_save_and_load_for_dtensors():
+    # Third Party
+    # NOTE: this is really a global replacement, which we use the patcher
+    # to do
+    from fms_acceleration.model_patcher import patch_target_module
+    patch_target_module("transformers.trainer.save_fsdp_model", save_fsdp_model)
+    patch_target_module("transformers.trainer.save_fsdp_optimizer", save_fsdp_optimizer)
+    patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model)
+    patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)
