
Commit 97a0bb4

fmt + lint
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
1 parent 83335c3 commit 97a0bb4

File tree

3 files changed: +63 -77 lines changed

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 53 additions & 65 deletions
@@ -13,22 +13,13 @@
 # limitations under the License.
 
 # Standard
+from collections import defaultdict
+from typing import List
+import json
 import os
 import re
-import json
-import torch
-
-from .scattermoe_constants import (
-    FILE_SAFETENSOR_INDEX,
-    PARAM_NAME_WEIGHT_SCATTERMOE,
-    PARAM_NAME_ROUTER_SCATTERMOE,
-    get_scattermoe_conv_spec_from_archs,
-)
-
-from .scattermoe_state_dict import get_checkpoint_meta_from_sharded_safetensor
 
 # Third Party
-from transformers import PretrainedConfig
 from accelerate.logging import get_logger
 from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME
 from torch.distributed.checkpoint.default_planner import (
@@ -37,9 +28,18 @@
 )
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+from transformers import PretrainedConfig
+import torch
 import torch.distributed.checkpoint as dcp
 
-from typing import List
+# Local
+from .scattermoe_constants import (
+    FILE_SAFETENSOR_INDEX,
+    PARAM_NAME_ROUTER_SCATTERMOE,
+    PARAM_NAME_WEIGHT_SCATTERMOE,
+    get_scattermoe_conv_spec_from_archs,
+)
+from .scattermoe_state_dict import get_checkpoint_meta_from_sharded_safetensor
 
 logger = get_logger(__name__)
 
@@ -181,6 +181,7 @@ def patch_huggingface_save_and_load_for_dtensors():
     patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model)
     patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)
 
+
 # trick to get the resolved cache file to acccess the safetensor
 # NOTE: this does not work if _dict_from_json_file, like GGUF files
 def get_resolved_checkpoint_location(model_name_or_path: str):
@@ -199,26 +200,29 @@ def _dict_from_json_file(resolved_config_file):
     PretrainedConfig._dict_from_json_file = _old_func
     return os.path.dirname(result)
 
+
 def restore_scattermoe_checkpoint_to_orig(
     dcp_checkpoint_dir: str,
     pretrained_model_name_or_path: str = None,
-    dcp_outer_key: str = 'model',
+    dcp_outer_key: str = "model",
 ):
-
     """
     Parameters:
         dcp_checkpoint_dir (str): the dcp to be converted.
-        pretrained_model_name_or_path (str): Optional, if provided we will 
-            use the hints to remap the 
+        pretrained_model_name_or_path (str): Optional, if provided we will
+            use the hints to remap the
     """
 
     # reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
     # - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap
 
     # guarded, load some internal functions
+    # pylint: disable=import-outside-toplevel
+    # Third Party
     from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
-    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
     from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+
     sd: STATE_DICT_TYPE = {}
     _load_state_dict(
         sd,
@@ -228,7 +232,7 @@ def restore_scattermoe_checkpoint_to_orig(
     )
     sd = sd[dcp_outer_key]
 
-    # if not provided 
+    # if not provided
     if pretrained_model_name_or_path is None:
         return sd
 
@@ -237,18 +241,16 @@
     with open(os.path.join(loc, FILE_SAFETENSOR_INDEX), encoding="utf-8") as f:
         index = json.load(f)
 
-    # config 
+    # config
     config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path)
 
     (
-        moe_cls,
+        _,
         router_name,
        expert_name,
-        expert_mlp_spec,
+        __,
         sharded_expert_ckpt,
-    ) = get_scattermoe_conv_spec_from_archs(
-        config.architectures
-    )
+    ) = get_scattermoe_conv_spec_from_archs(config.architectures)
 
     # the sd from the module swap must have keys like
     # 'model.layers.0.block_sparse_moe.w1.weight'
@@ -258,13 +260,12 @@
     # prefix = model.layers.0 and module_name = block_sparse_moe
 
     def _infer_prefixes_and_module_names(
-        sd_keys: List[str], min_count: int = 3,
+        sd_keys: List[str],
+        min_count: int = 3,
     ):
-        _name = "|".join([
-            PARAM_NAME_ROUTER_SCATTERMOE,
-            *PARAM_NAME_WEIGHT_SCATTERMOE
-        ])
-        _reg = re.compile(f'(.*)\.({_name})\.weight')
+        _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE])
+        # pylint: disable=anomalous-backslash-in-string
+        _reg = re.compile(f"(.*)\.({_name})\.weight")
         found = {}
 
         for k in sd_keys:
@@ -276,22 +277,22 @@ def _infer_prefixes_and_module_names(
                 found[prefix] = 1 + found.get(prefix, 0)
 
         results = []
-        for prefix in found.keys():
+        for prefix, cnt in found.items():
             # if at least router, w1 and w2 are found, take it
             # otherwise we delete
-            if found[prefix] >= min_count:
+            if cnt >= min_count:
                 results.append(prefix)
 
         return results
 
     for prefix in _infer_prefixes_and_module_names(sd.keys()):
-        prefix = prefix.split('.')
+        prefix = prefix.split(".")
         prefix, module_name = ".".join(prefix[:-1]), prefix[-1]
 
         # checkpoint metadata is will be a map
         # key -> list of tuples
         # where each in the list is (param_name, stfile)
-        # - if the list is larger than one, it means that the 
+        # - if the list is larger than one, it means that the
         # actual model has a sharded checkpoint
 
         # defaultdict(list,
@@ -312,22 +313,12 @@ def _infer_prefixes_and_module_names(
             expert_name,
         )
 
-        # check if expert name has repeats
-        # has_repeat = False
-        # _seen = set()
-        # for k in expert_name.split('|'):
-        #     if k in _seen:
-        #         has_repeat = True
-        #         break
-        #     _seen.add(k)
-
-        from collections import defaultdict
         model2scatter = defaultdict(dict)
         # construct a map of model_key -> {scatter_key: [params, ...]}
-        # - if the param list > 1, that means many scatter keys map to 1 
+        # - if the param list > 1, that means many scatter keys map to 1
         # model param and they need to be cat
         for scatter_key, list_of_params in checkpoint_metadata.items():
-            scatter_key_fqdn = '.'.join([prefix, module_name, scatter_key])
+            scatter_key_fqdn = ".".join([prefix, module_name, scatter_key])
             scatter_param = sd[scatter_key_fqdn]
 
             # remove from state dict
@@ -341,13 +332,11 @@ def _infer_prefixes_and_module_names(
             else:
                 # if sharded, we just assume that there should be 1 expert
                 # per shard
-                assert n == scatter_param.shape[0], \
-                    "Sharded expert weights should be 1 expert per shard."
+                assert (
+                    n == scatter_param.shape[0]
+                ), "Sharded expert weights should be 1 expert per shard."
 
-            if any(
-                scatter_key.startswith(k) for k in 
-                PARAM_NAME_WEIGHT_SCATTERMOE
-            ):
+            if any(scatter_key.startswith(k) for k in PARAM_NAME_WEIGHT_SCATTERMOE):
                 scatter_param = scatter_param.permute(0, 2, 1)
 
             # go through all the model keys
@@ -363,33 +352,33 @@ def _infer_prefixes_and_module_names(
                 model2scatter[model_key][scatter_key] = _param
 
         # replace them back in the sd
-        for model_key in list(model2scatter.keys()): 
+        for model_key in list(model2scatter.keys()):
 
             scatter_params = model2scatter[model_key]
 
             # - there is an assumption that the ifthere is a cat, then
             # it will go by order of scatter keys
             scatter_keys = sorted(scatter_params.keys())
 
-            assert len(scatter_keys) > 0, f"Obtained zero scatter keys for model_key \'{model_key}\'"
+            assert (
+                len(scatter_keys) > 0
+            ), f"Obtained zero scatter keys for model_key '{model_key}'"
 
             if len(scatter_keys) == 1:
                 sd[model_key] = scatter_params[scatter_keys[0]]
             else:
-                # unfortunately, there this is a in 
+                # unfortunately, there this is a in
                 # scattermoe_state_dict._maybe_reshape_scattermoe_expert_weights
                 # that we split on the dim=1, so we cat back on that
                 sd[model_key] = torch.cat(
-                    [scatter_params[k] for k in scatter_keys],
-                    dim=1
+                    [scatter_params[k] for k in scatter_keys], dim=1
                 )
 
             # remove from this intemediate mapping
             del model2scatter[model_key]
 
-        rem_keys = ",".join([k for k in model2scatter])
-        assert len(rem_keys) == 0, \
-            f"Did not handle model parameters \'{rem_keys}\'"
+        rem_keys = ",".join(list(model2scatter))
+        assert len(rem_keys) == 0, f"Did not handle model parameters '{rem_keys}'"
 
     return sd
 
@@ -398,7 +387,8 @@ def _infer_prefixes_and_module_names(
 
 
 # have it serve as a conversion script
-if __name__ == '__main__':
+if __name__ == "__main__":
+    # Standard
     import argparse
 
     parser = argparse.ArgumentParser(
@@ -417,7 +407,5 @@ def _infer_prefixes_and_module_names(
             "In order to reconstruct the state dict, we requre hints from "
             "the original pretrained model checkpoint (from which this "
             "checkpoint is obtained)."
-        )
+        ),
     )
-
-
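Not part of the commit: a minimal usage sketch of the function touched above. The import path is assumed from the file location, and the checkpoint directory, model name, and output path are placeholders, not values taken from this diff.

import torch

from fms_acceleration_moe.utils.checkpoint_utils import (
    restore_scattermoe_checkpoint_to_orig,
)

# Read a torch.distributed.checkpoint (dcp) folder written during ScatterMoE
# training and remap it back to the original pretrained-model key layout.
sd = restore_scattermoe_checkpoint_to_orig(
    dcp_checkpoint_dir="outputs/checkpoint-100/pytorch_model_fsdp_0",  # placeholder path
    pretrained_model_name_or_path="mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder model
)
torch.save(sd, "restored_state_dict.bin")  # placeholder output path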

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py

Lines changed: 7 additions & 7 deletions
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Standard
 from typing import List
 
 # to be updated so that the parsers can work properly
-PARAM_NAME_ROUTER_SCATTERMOE = 'router'
-PARAM_NAME_WEIGHT_SCATTERMOE = ['w1', 'w2', 'w3']
+PARAM_NAME_ROUTER_SCATTERMOE = "router"
+PARAM_NAME_WEIGHT_SCATTERMOE = ["w1", "w2", "w3"]
 
 FILE_SAFETENSOR_INDEX = "model.safetensors.index.json"
 KEY_REPLICATE = "replicate"
@@ -79,9 +80,8 @@
 
 # helper function to get the spec based on architectures
 
-def get_scattermoe_conv_spec_from_archs(
-    architectures: List[str]
-):
+
+def get_scattermoe_conv_spec_from_archs(architectures: List[str]):
     # infer the spec
     for archs, spec in SCATTERMOE_CONVERSION_SPEC.items():
         archs = archs.split(",")
@@ -90,6 +90,6 @@ def get_scattermoe_conv_spec_from_archs(
 
     # if not found
     raise ValueError(
-        f"In order to configure ScatterMoe for archs \'{architectures}\' "
+        f"In order to configure ScatterMoe for archs '{architectures}' "
         "the conversion spect must be updated in scattermoe_constants.py"
-    )
+    )
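Also not part of the commit: a sketch of how callers in this diff consume the reformatted helper. The tuple order mirrors the unpacking seen in checkpoint_utils.py and scattermoe_prepare.py; the example architecture string is an assumption, and the inline comments are hints, not documented semantics.

from fms_acceleration_moe.utils.scattermoe_constants import (
    get_scattermoe_conv_spec_from_archs,
)

(
    moe_cls,              # MoE block hint from the spec table
    router_name,          # router parameter name in the original model
    expert_name,          # "|"-joined expert weight names, split by callers
    expert_mlp_spec,      # expert MLP layout hint
    sharded_expert_ckpt,  # whether the original checkpoint shards experts
) = get_scattermoe_conv_spec_from_archs(["MixtralForCausalLM"])  # assumed example arch

# Architectures missing from SCATTERMOE_CONVERSION_SPEC raise the ValueError shown above.
try:
    get_scattermoe_conv_spec_from_archs(["SomeUnsupportedArch"])
except ValueError as err:
    print(err)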

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py

Lines changed: 3 additions & 5 deletions
@@ -30,6 +30,7 @@
 import torch
 
 # Local
+from .checkpoint_utils import get_resolved_checkpoint_location
 from .scattermoe_constants import (
     FILE_SAFETENSOR_INDEX,
     KEY_EXPERT_PARALLEL,
@@ -43,7 +44,6 @@
     get_state_dict_from_checkpoint_metadata,
 )
 
-from .checkpoint_utils import get_resolved_checkpoint_location
 
 # this function will load the sharded experts onto the device.
 # - this assumes that the "dmoe" module is the megablocks.layers.dmoe.dMoE distributed
@@ -72,7 +72,7 @@ def _hook(grad):
             # if its the router, replicate
             param = distribute_tensor(param, device_mesh, reps + [Replicate()])
         elif param.shape[0] > num_experts_per_device:
-            # if its a weight param and the number of experts exceed that of 
+            # if its a weight param and the number of experts exceed that of
             # the device, shard
             param = distribute_tensor(param, device_mesh, reps + [Shard(0)])
         else:
@@ -138,9 +138,7 @@ def prepare_scattermoe(
         expert_name,
         expert_mlp_spec,
         sharded_expert_ckpt,
-    ) = get_scattermoe_conv_spec_from_archs(
-        model.config.architectures
-    )
+    ) = get_scattermoe_conv_spec_from_archs(model.config.architectures)
 
     # split the names first
     expert_name = expert_name.split("|")
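For orientation only: the comment reflow in the @@ -72,7 +72,7 @@ hunk sits inside the expert-placement logic. The rule visible there can be restated as a tiny illustrative helper; the names and string return values below are not the plugin's API, the router check is an assumed stand-in for the real condition, and the branch not shown in the hunk is left out.

def placement_for(param_name: str, dim0: int, num_experts_per_device: int) -> str:
    # Router parameters are replicated across the expert-parallel mesh dimension.
    if "router" in param_name:  # illustrative check; the actual test is not shown in the hunk
        return "Replicate()"
    # Expert weights with more experts than fit on one device are sharded on dim 0.
    if dim0 > num_experts_per_device:
        return "Shard(0)"
    return "branch not shown in this hunk"


print(placement_for("model.layers.0.block_sparse_moe.router.weight", 8, 2))  # Replicate()
print(placement_for("model.layers.0.block_sparse_moe.w1.weight", 8, 2))      # Shard(0)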
