
Commit 6d1e646

update readme, code cleanup, add comments and initial bench
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
1 parent f717606 commit 6d1e646

File tree

3 files changed (+21, -90 lines changed)

plugins/accelerated-moe/README.md

Lines changed: 5 additions & 5 deletions
@@ -1,7 +1,7 @@
 # FMS Acceleration for Mixture-of-Experts
 
 This library contains plugins to accelerate finetuning with the following optimizations:
-1. Expert-Parallel MoE with ScatterMoe & Megablocks
+1. Expert-Parallel MoE with Triton Kernels from ScatterMoe (& Megablocks)
 
 ## Plugins
 
@@ -20,12 +20,12 @@ Run the below in the top-level directory of this repo:
 tox -e run-benches \
 -x testenv:run-benches.deps+="-r plugins/accelerated-moe/requirements-khd.txt" \
 -- \
-"1 2" "4 8" benchmark_outputs scenarios-granite.yaml accelerated-moe-scatter
+"1 2 4 8" 128 benchmark_outputs scenarios-granite.yaml accelerated-moe-scatter
 ```
 or run the larger `Mixtral-8x7B` bench:
 ```
 tox ... \
-8 8 benchmark_outputs scenarios-granite.yaml accelerated-moe-scatter
+8 128 benchmark_outputs scenarios-granite.yaml accelerated-moe-scatter
 ```
 
 NOTE: if `FileNotFoundError` is observed on the *triton cache*, similar to issues like these:
@@ -40,10 +40,10 @@ running in `bash`:
 
 source .tox/run-benches/bin/activate
 bash scripts/run_benchmarks.sh \
-"1 2" "4 8" benchmark_outputs scenarios-granite.yaml accelerated-moe-scatter
+....
 ```
 
-### Megablocks Dependencies
+### Triton Kernel Dependencies
 
 Currently we do not copy the `scattermoe` kernels into this respository, to this is an additional manual install:
 
plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 3 additions & 85 deletions
@@ -36,108 +36,31 @@ class ScatterMoEAccelerationPlugin(AccelerationPlugin):
     def __init__(self, configurations: Dict[str, Dict]):
         super().__init__(configurations)
 
-        # arguments for configuring the mixture-of-experts model with defaults
-        # shown below for Mixtral 7x8b
-        # - 1. component class
-        # self._moe_component_cls = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.moe_component_class",
-        #     # default="MixtralSparseMoeBlock",
-        #     default="GraniteMoeMoE",
-        # )
-        # - 2. gate_module_name
-        # self._gate_module_name = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.moe_gate_module_name", default="gate"
-        # )
-        # # - 3. experts_module_name
-        # self._experts_module_name = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.moe_experts_module_name", default="experts"
-        # )
-        # # - 4. mlp version
-        # self._mlp_version = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.moe_mlp_impl",
-        #     values=["v1", "v2"],
-        #     default="v2",
-        # )
-
-        # for controlling the type of sharding
-        # self._shard_along_dp = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.shard_along_dp",
-        #     values=[True, False],
-        #     default=True,
-        # )
-
-        # ep_size determines the expert parallel sharding
-        # - ep_size is ignored if _shard_along_dp=True
+        # ep_degree determines the expert parallel sharding
+        # - default of 1 means experts are not sharded and operate in pure replication.
         self._ep_degree = self._check_config_and_maybe_check_values(
             key="training.moe.scattermoe.ep_degree",
             default=1,
         )
 
-        # for the moe_implementation, currently we only use the megablocks
-        # dropless sparse implementation
-        # self._moe_implementation = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.moe_implementation",
-        #     values=["dropless_sparse"],
-        #     default="dropless_sparse",
-        # )
-        # self._moe_implementation = self._moe_implementation.split("_")[1]
-
-        # self._load_balancing_loss = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.load_balancing_loss",
-        #     values=[True, False],
-        #     default=False,
-        # )
-
     @property
     def requires_custom_loading(self):
         return True
 
     def model_loader(self, model_name: str, **kwargs):
+
         # guarded
         # Local
         # pylint: disable=import-outside-toplevel
-        # from .megablocks_utils.config_utils import update_mlp_registry
-        # from .megablocks_utils.shard_moe_utils import get_moe_kwargs, shard_moe
-
-        # # - check the config
-        # if self._load_balancing_loss and not hasattr(
-        #     AutoConfig.from_pretrained(model_name), "output_router_logits"
-        # ):
-        #     warnings.warn(
-        #         "load_balancing_loss=True but "
-        #         "the model '{model_name}' config not have 'output_router_logits' "
-        #         "in its config, hence it might not support load balancing and "
-        #         "fallback to load_balancing_loss=False."
-        #     )
-        #     self._load_balancing_loss = False
-
-        # this one does a forward patching on MLP, but needs to be fixed
-        # properly as the load balancing loss is currently not properly
-        # handled
-        # update_mlp_registry(
-        #     self._moe_implementation, self._mlp_version, self._load_balancing_loss
-        # )
         from .utils import prepare_scattemoe
 
-        # get additional parameters
-        # torch_dtype = kwargs.get("torch_dtype", torch.float32)
-
         # load the model
         model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
 
-        # set this in the config, which will be picked up by the forward
-        # function to go into the load_balancing loss
-        # model.config.output_router_logits = self._load_balancing_loss
-
         rank, world_size = 0, 1
         if torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
-        # else:
-        #     # NOTE: or should we do a silent fallback
-        #     raise AssertionError(
-        #         "Megablocks expert parallel only works for distributed training."
-        #     )
 
         # shard the MOE, and store products required for
         # FSDP configuration
@@ -159,11 +82,6 @@ def model_loader(self, model_name: str, **kwargs):
         # flag from train_args. It will be better to handle this if
         # when we move the sharding to augmentation.
 
-        # NOTE: Currently, it is a bit troublesome to pass the device_mesh to
-        # the FSDP constructor, so we do not do that.
-        # - therefore FSDP will always shard on world_size over the default process
-        # group
-
         return model
 
     def get_callbacks_and_ready_for_train(
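For readers skimming the diff, the sketch below is illustrative only and is not part of this commit: `partition_experts` is a hypothetical helper that shows what the new `training.moe.scattermoe.ep_degree` knob controls, while the actual sharding is performed by `prepare_scattemoe` from `fms_acceleration_moe.utils`.

```
# Hypothetical illustration of expert-parallel sharding controlled by ep_degree.
# The real sharding lives in fms_acceleration_moe.utils.prepare_scattemoe; this
# toy helper only shows the intended layout of experts over ranks.
def partition_experts(num_experts: int, ep_degree: int, rank: int) -> list:
    # ep_degree=1 (the plugin default) keeps every expert on every rank,
    # i.e. pure replication with no expert-parallel sharding.
    if ep_degree == 1:
        return list(range(num_experts))
    assert num_experts % ep_degree == 0, "experts must divide evenly over ep_degree"
    per_rank = num_experts // ep_degree
    ep_rank = rank % ep_degree  # this rank's position inside its expert-parallel group
    return list(range(ep_rank * per_rank, (ep_rank + 1) * per_rank))


# e.g. a Mixtral-8x7B-style model with 8 experts and ep_degree=4:
# each expert-parallel rank holds 2 of the 8 experts locally.
print(partition_experts(8, 4, rank=1))  # -> [2, 3]
```

With the default `ep_degree=1`, every rank keeps the full expert set, matching the "pure replication" comment added in `__init__` above.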
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+epoch,framework_config,gradient_accumulation_steps,mem_nvidia_mem_reserved,model_name_or_path,num_gpus,per_device_train_batch_size,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
+0.25,none,16,70749.0,ibm/PowerMoE-3b,1,8,bfloat16,0.9477007621526718,2350.0523,5.447,0.043,1508.732
+0.25,none,8,46699.0,ibm/PowerMoE-3b,2,8,bfloat16,0.9477724695205688,1341.9179,9.539,0.075,1321.094
+0.25,none,4,38885.0,ibm/PowerMoE-3b,4,8,bfloat16,0.9478064042329788,712.2347,17.972,0.14,1244.534
+0.25,moe-scattermoe-granite-ep1,16,71049.0,ibm/PowerMoE-3b,1,8,bfloat16,0.9477236008644104,741.1462,17.271,0.135,4783.942
+0.25,moe-scattermoe-granite-ep1,8,52294.0,ibm/PowerMoE-3b,2,8,bfloat16,0.9511111199855804,484.7077,26.408,0.206,3657.462
+0.25,moe-scattermoe-granite-ep1,4,51251.5,ibm/PowerMoE-3b,4,8,bfloat16,0.9541541540622711,264.6776,48.361,0.378,3348.98
+,moe-scattermoe-granite-ep2,16,3.0,ibm/PowerMoE-3b,1,8,bfloat16,,,,,
+0.25,moe-scattermoe-granite-ep2,8,39854.0,ibm/PowerMoE-3b,2,8,bfloat16,0.9480846971273422,602.4418,21.247,0.166,2942.691
+0.25,moe-scattermoe-granite-ep2,4,40937.0,ibm/PowerMoE-3b,4,8,bfloat16,0.9512380701303482,305.5111,41.897,0.327,2901.367
+,moe-scattermoe-granite-ep4,16,3.0,ibm/PowerMoE-3b,1,8,bfloat16,,,,,
+,moe-scattermoe-granite-ep4,8,213.0,ibm/PowerMoE-3b,2,8,bfloat16,,,,,
+0.25,moe-scattermoe-granite-ep4,4,32128.0,ibm/PowerMoE-3b,4,8,bfloat16,0.9484522187709808,314.6519,40.68,0.318,2817.082
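The benchmark CSV above is easier to compare once pivoted by config and GPU count. Below is a minimal sketch, assuming pandas is installed and the CSV has been saved locally as `raw_summary.csv` (a hypothetical filename, since the file's actual path is not shown in this view).

```
# Minimal sketch for summarizing the benchmark CSV added in this commit.
import pandas as pd

df = pd.read_csv("raw_summary.csv")  # hypothetical local copy of the CSV above

# Rows with a missing train_loss are runs that did not complete
# (e.g. the ep2/ep4 configs run on fewer GPUs than their expert-parallel degree).
completed = df.dropna(subset=["train_loss"])

# Average token throughput per framework_config and GPU count.
throughput = (
    completed
    .groupby(["framework_config", "num_gpus"])["train_tokens_per_second"]
    .mean()
    .unstack("num_gpus")
)
print(throughput)  # e.g. compare moe-scattermoe-granite-ep1 against the "none" baseline
```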
