@@ -36,108 +36,31 @@ class ScatterMoEAccelerationPlugin(AccelerationPlugin):
     def __init__(self, configurations: Dict[str, Dict]):
         super().__init__(configurations)

-        # arguments for configuring the mixture-of-experts model with defaults
-        # shown below for Mixtral 7x8b
-        # - 1. component class
-        # self._moe_component_cls = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.moe_component_class",
-        #     # default="MixtralSparseMoeBlock",
-        #     default="GraniteMoeMoE",
-        # )
-        # - 2. gate_module_name
-        # self._gate_module_name = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.moe_gate_module_name", default="gate"
-        # )
-        # # - 3. experts_module_name
-        # self._experts_module_name = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.moe_experts_module_name", default="experts"
-        # )
-        # # - 4. mlp version
-        # self._mlp_version = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.moe_mlp_impl",
-        #     values=["v1", "v2"],
-        #     default="v2",
-        # )
-
-        # for controlling the type of sharding
-        # self._shard_along_dp = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.shard_along_dp",
-        #     values=[True, False],
-        #     default=True,
-        # )
-
-        # ep_size determines the expert parallel sharding
-        # - ep_size is ignored if _shard_along_dp=True
+        # ep_degree determines the expert parallel sharding
+        # - default of 1 means experts are not sharded and operate in pure replication.
         self._ep_degree = self._check_config_and_maybe_check_values(
             key="training.moe.scattermoe.ep_degree",
             default=1,
         )

-        # for the moe_implementation, currently we only use the megablocks
-        # dropless sparse implementation
-        # self._moe_implementation = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.moe_implementation",
-        #     values=["dropless_sparse"],
-        #     default="dropless_sparse",
-        # )
-        # self._moe_implementation = self._moe_implementation.split("_")[1]
-
-        # self._load_balancing_loss = self._check_config_and_maybe_check_values(
-        #     key="training.moe.scattermoe.load_balancing_loss",
-        #     values=[True, False],
-        #     default=False,
-        # )
-
     @property
     def requires_custom_loading(self):
         return True

     def model_loader(self, model_name: str, **kwargs):
+
         # guarded
         # Local
         # pylint: disable=import-outside-toplevel
-        # from .megablocks_utils.config_utils import update_mlp_registry
-        # from .megablocks_utils.shard_moe_utils import get_moe_kwargs, shard_moe
-
-        # # - check the config
-        # if self._load_balancing_loss and not hasattr(
-        #     AutoConfig.from_pretrained(model_name), "output_router_logits"
-        # ):
-        #     warnings.warn(
-        #         "load_balancing_loss=True but "
-        #         "the model '{model_name}' config not have 'output_router_logits' "
-        #         "in its config, hence it might not support load balancing and "
-        #         "fallback to load_balancing_loss=False."
-        #     )
-        #     self._load_balancing_loss = False
-
-        # this one does a forward patching on MLP, but needs to be fixed
-        # properly as the load balancing loss is currently not properly
-        # handled
-        # update_mlp_registry(
-        #     self._moe_implementation, self._mlp_version, self._load_balancing_loss
-        # )
         from .utils import prepare_scattemoe

-        # get additional parameters
-        # torch_dtype = kwargs.get("torch_dtype", torch.float32)
-
         # load the model
         model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)

-        # set this in the config, which will be picked up by the forward
-        # function to go into the load_balancing loss
-        # model.config.output_router_logits = self._load_balancing_loss
-
         rank, world_size = 0, 1
         if torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
             rank = torch.distributed.get_rank()
-        # else:
-        #     # NOTE: or should we do a silent fallback
-        #     raise AssertionError(
-        #         "Megablocks expert parallel only works for distributed training."
-        #     )

         # shard the MOE, and store products required for
         # FSDP configuration
@@ -159,11 +82,6 @@ def model_loader(self, model_name: str, **kwargs):
         # flag from train_args. It will be better to handle this if
         # when we move the sharding to augmentation.

-        # NOTE: Currently, it is a bit troublesome to pass the device_mesh to
-        # the FSDP constructor, so we do not do that.
-        # - therefore FSDP will always shard on world_size over the default process
-        # group
-
         return model

     def get_callbacks_and_ready_for_train(
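For context on what the retained training.moe.scattermoe.ep_degree knob implies at runtime, below is a minimal, hypothetical sketch, not code from this commit: it assumes the world size is divisible by ep_degree and uses torch.distributed.device_mesh.init_device_mesh (available in recent PyTorch releases) to split ranks into a data-parallel dimension and an expert-parallel dimension. The helper name build_expert_parallel_mesh and the mesh dimension names are invented for illustration.

# Hypothetical sketch only -- not part of ScatterMoEAccelerationPlugin.
# It illustrates one common way an ep_degree setting can map onto a 2-D
# process layout: one dimension for data parallelism (replication / FSDP
# sharding), one dimension for expert-parallel sharding.
import torch
from torch.distributed.device_mesh import init_device_mesh


def build_expert_parallel_mesh(ep_degree: int = 1):
    """Split world_size into (dp_degree, ep_degree) mesh dimensions."""
    if not torch.distributed.is_initialized():
        # mirrors the single-process fallback `rank, world_size = 0, 1` above
        return None

    world_size = torch.distributed.get_world_size()
    if world_size % ep_degree != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by ep_degree={ep_degree}"
        )

    dp_degree = world_size // ep_degree
    # ep_degree=1 (the plugin default) degenerates into a pure
    # data-parallel layout, i.e. every rank holds all experts.
    return init_device_mesh(
        "cuda", (dp_degree, ep_degree), mesh_dim_names=("dp", "ep")
    )

In a layout like this, expert weights would be partitioned along the "ep" dimension while the remaining parameters are replicated or FSDP-sharded along "dp"; with the default ep_degree=1 the expert dimension collapses, matching the "pure replication" comment added in this commit.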