
Commit bea38df

support ep for fsdp
1 parent f48c4af commit bea38df

10 files changed: +485 −41 lines changed

configs/7B_sft.py

Lines changed: 1 addition & 12 deletions
@@ -22,24 +22,15 @@
 CHECKPOINT_EVERY = 50
 ckpt = dict(
     enable_save_ckpt=False,  # enable ckpt save.
-    enable_internevo2hf_ckpt=False,  # enable ckpt save for huggingface format.
     save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
-    # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
-    load_ckpt_folder="local:llm_ckpts/",
-    # 'load_ckpt_info' setting guide:
-    # 1. the 'path' indicate ckpt path,
-    # 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
-    # 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
-    #    load function such as "llama"
-    load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
     # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
     # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
     # with an automatic restart mechanism upon training reboot.
     # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
     # path specified in `load_ckpt_info` by default.
     # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
     # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
-    auto_resume=True,
+    auto_resume=False,
     checkpoint_every=CHECKPOINT_EVERY,
     async_upload=True,  # async ckpt upload. (only work for boto3 ckpt)
     async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporarily files during asynchronous upload.
@@ -144,14 +135,12 @@
 model = dict(
     checkpoint=False,  # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
-    embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
     embed_grad_scale=1,
     parallel_output=True,
     hidden_size=HIDDEN_SIZE,
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
-    apply_post_layer_norm=False,
     dtype="torch.bfloat16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
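
Note: the functional change in this config is `auto_resume=False`. Per the comments kept above, initial weights then come from `load_ckpt_info` if it is set, or training starts from scratch if it is None. A minimal sketch of the two setups under that convention; the paths are placeholders, not values from this commit:

# Sketch 1: fine-tune from an existing checkpoint (auto_resume must stay False,
# otherwise the latest ckpt under save_ckpt_folder would be picked up instead).
ckpt = dict(
    enable_save_ckpt=True,
    save_ckpt_folder="local:llm_ckpts/",          # placeholder path
    load_ckpt_info=dict(
        path="local:/path/to/pretrained_ckpt/",   # placeholder path
        content=("model",),                       # only model weights when auto_resume=False
        ckpt_type="internevo",                    # or "hf" for HuggingFace-format checkpoints
    ),
    auto_resume=False,
    checkpoint_every=50,
)

# Sketch 2: train from scratch.
ckpt = dict(
    enable_save_ckpt=True,
    save_ckpt_folder="local:llm_ckpts/",
    load_ckpt_info=None,
    auto_resume=False,
    checkpoint_every=50,
)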

internlm/checkpoint/checkpoint_manager.py

Lines changed: 1 addition & 1 deletion
@@ -582,7 +582,7 @@ def try_resume_training(self, train_state: TrainState, current_time=""):
                 f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
                 f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
             )
-        elif is_using_fsdp() and is_using_hf() and not self.auto_resume:
+        elif is_using_fsdp() and not self.auto_resume:
             pass
         else:
             load_path = self.load_ckpt_info["path"]
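
Note: dropping the `is_using_hf()` condition means every FSDP run with `auto_resume=False` now skips the resume path here, because under FSDP the initial weights are already loaded while wrapping the model (see internlm/core/fsdp.py below). A paraphrased sketch of the resulting branch logic, not the literal method body:

if self.auto_resume:
    ...  # resume from the latest checkpoint under save_ckpt_folder (assumed surrounding logic)
elif is_using_fsdp() and not self.auto_resume:
    # Nothing to do: wrap_FSDP_model() already loaded the weights named in load_ckpt_info.
    pass
else:
    load_path = self.load_ckpt_info["path"]
    ...  # load the checkpoint specified in load_ckpt_info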

internlm/core/fsdp.py

Lines changed: 41 additions & 11 deletions
@@ -33,8 +33,10 @@
 FSDP2_SUPPORTED = False
 
 try:
+    import torch.distributed.checkpoint as dcp
     from torch.distributed.checkpoint.state_dict import (
         StateDictOptions,
+        get_model_state_dict,
         set_model_state_dict,
     )
 
@@ -163,8 +165,29 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     )
     fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1")
     fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda")
+    if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+        assert gpc.get_world_size(ParallelMode.EXPERT_DATA) * gpc.get_world_size(ParallelMode.EXPERT) == gpc.get_world_size(ParallelMode.GLOBAL)
 
     if fsdp_mode == "v1":
+        ignored_mod = []
+        if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+            for layer_id, layer in enumerate(model.model.layers):
+                if layer_id >= gpc.config.model.first_k_dense_replace:
+                    # Should follow this modeling pattern if EP is enabled.
+                    # Change the expert module name if needed.
+                    # TODO: Make this part hard-coded or config-driven?
+                    layer.feed_forward.moe_layer.experts = FSDP(
+                        layer.feed_forward.moe_layer.experts,
+                        process_group=gpc.get_group(ParallelMode.EXPERT_DATA),
+                        sharding_strategy=ShardingStrategy.FULL_SHARD,
+                        sync_module_states=fsdp_init_method != "cuda",  # sync model paramters
+                        forward_prefetch=True,
+                        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+                        limit_all_gathers=True,
+                        use_orig_params=True,
+                        device_id=None if fsdp_init_method == "cuda" else get_current_device(),  # needed for sync_module_states
+                    )
+                    ignored_mod.append(layer.feed_forward.moe_layer.experts)
         model = FSDP(
             module=model,
             process_group=gpc.get_group(ParallelMode.GLOBAL),
@@ -176,6 +199,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             limit_all_gathers=True,
             use_orig_params=True,
             device_id=None if fsdp_init_method == "cuda" else get_current_device(),  # needed for sync_module_states
+            ignored_modules=ignored_mod,
         )
         # For FSDP v1, to get ckpt resuming work normally, we do dummy forward.
         # This hack is needed due to FSDP v1 lazy initialization in model construction.
@@ -196,7 +220,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     else:
         raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}")
 
-    if is_using_hf() and not gpc.config.ckpt.get("auto_resume", False):
+    if not gpc.config.ckpt.get("auto_resume", False):
         load_ckpt_info = gpc.config.ckpt.load_ckpt_info
         load_ckpt_path = load_ckpt_info.get("path", None)
         load_ckpt_content = load_ckpt_info.get("content", [])
@@ -205,16 +229,22 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             "model",
         ), "If auto_resume=False and checkpoint path is given, only model can be loaded"
         if DCP_SUPPORTED:
-            hf = gpc.config.hf
-            mod = LazyObject(hf.mod, hf.mod_cls)
-            mod = mod.build()
-            state_dict = mod.from_pretrained(
-                pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True
-            ).state_dict()
-            state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict}
-            set_model_state_dict(
-                model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True)
-            )
+            if is_using_hf():
+                hf = gpc.config.hf
+                mod = LazyObject(hf.mod, hf.mod_cls)
+                mod = mod.build()
+                state_dict = mod.from_pretrained(
+                    pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True
+                ).state_dict()
+                state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict}
+                set_model_state_dict(
+                    model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True)
+                )
+            else:
+                state_dict = get_model_state_dict(model=model)
+                state_dict = {key: state_dict[key].clone().detach() for key in state_dict}
+                dcp.load(state_dict=state_dict, checkpoint_id=load_ckpt_path)
+                set_model_state_dict(model=model, model_state_dict=state_dict)
             del state_dict
             internlm_accelerator.empty_cache()
         else:
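
Note: with expert parallelism (EP) enabled, the commit applies a two-level wrapping scheme. Each MoE layer's expert module is sharded by its own FSDP instance over the expert-data group, and the outer FSDP instance shards the rest of the model over the global group while ignoring the already-wrapped experts. A standalone sketch of that pattern using plain PyTorch FSDP; the process-group handles and the `model.model.layers[i].feed_forward.moe_layer.experts` layout are assumptions mirroring the code above:

from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_moe_model(model, global_group, expert_data_group, first_k_dense_replace):
    ignored = []
    for layer_id, layer in enumerate(model.model.layers):
        if layer_id < first_k_dense_replace:
            continue  # leading dense layers are handled by the outer FSDP instance
        # Expert weights differ across EP ranks and are only replicated within the
        # expert-data group, so they must be sharded and reduced over that group.
        layer.feed_forward.moe_layer.experts = FSDP(
            layer.feed_forward.moe_layer.experts,
            process_group=expert_data_group,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            use_orig_params=True,
        )
        ignored.append(layer.feed_forward.moe_layer.experts)
    # Everything else is sharded over the global group; ignored_modules keeps the
    # outer instance from re-flattening the already-wrapped expert parameters.
    return FSDP(
        model,
        process_group=global_group,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        use_orig_params=True,
        ignored_modules=ignored,
    )

The new non-HF branch restores weights with torch.distributed.checkpoint (DCP). A minimal sketch of that load pattern, assuming `ckpt_dir` points at a DCP checkpoint produced from the same model definition:

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)


def load_dcp_checkpoint(wrapped_model, ckpt_dir):
    # Build a state-dict skeleton from the wrapped model, fill it in place from the
    # checkpoint shards, then push the loaded tensors back into the model.
    state_dict = get_model_state_dict(model=wrapped_model)
    dcp.load(state_dict=state_dict, checkpoint_id=ckpt_dir)
    set_model_state_dict(model=wrapped_model, model_state_dict=state_dict)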

internlm/initialize/initialize_launcher.py

Lines changed: 5 additions & 2 deletions
@@ -57,6 +57,8 @@ def dispatch_hf_config_before_launch(hf: dict) -> None:
         gpc.config.model.num_experts = model_config.num_experts
     elif hasattr(model_config, "n_routed_experts"):
         gpc.config.model.num_experts = model_config.n_routed_experts
+    if hasattr(model_config, "first_k_dense_replace"):
+        gpc.config.model.first_k_dense_replace = model_config.first_k_dense_replace
 
 
 def args_sanity_check():
@@ -306,8 +308,9 @@ def args_sanity_check():
         logger.info(f"clip_grad_norm: {clip_grad_norm}")
 
     model = gpc.config.model
-    if "enable_qkv_fusion" not in model:
-        model._add_item("enable_qkv_fusion", True)
+    # TODO: should we set default value for enable_qkv_fusion?
+    # if "enable_qkv_fusion" not in model:
+    #     model._add_item("enable_qkv_fusion", True)
 
     if "dtype" not in model:
         logger.warning("dtype is not set, use torch.float16 by defalut!")
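
Note: `first_k_dense_replace` is the DeepSeek-style MoE config field giving the number of leading decoder layers that stay dense before MoE layers begin; the FSDP wrapper above uses it to decide which layers contain expert modules. A sketch of pulling these fields off a HuggingFace config object, mirroring the dispatch logic above (the checkpoint path and the use of trust_remote_code are assumptions):

from transformers import AutoConfig


def read_moe_fields(hf_path):
    # hf_path is a placeholder; trust_remote_code may be needed for custom configs.
    cfg = AutoConfig.from_pretrained(hf_path, trust_remote_code=True)
    if hasattr(cfg, "num_experts"):
        num_experts = cfg.num_experts
    elif hasattr(cfg, "n_routed_experts"):
        num_experts = cfg.n_routed_experts
    else:
        num_experts = 1  # dense model
    first_k_dense_replace = getattr(cfg, "first_k_dense_replace", 0)
    return num_experts, first_k_dense_replace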

internlm/initialize/initialize_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def split_params_into_different_groups_for_optimizer(
 
     if is_using_fsdp():
         optimizer_mode = ParallelMode.GLOBAL
-        optimizer_mode_expert = ParallelMode.GLOBAL
+        optimizer_mode_expert = ParallelMode.EXPERT_DATA
     else:
         optimizer_mode = ParallelMode.ZERO1
         optimizer_mode_expert = ParallelMode.EXPERT_DATA
internlm/model/model_ops/ops/cross_entropy.py

Lines changed: 3 additions & 11 deletions
@@ -18,6 +18,7 @@
     CrossEntropyApexVocabParallel,
     CrossEntropyLossApex,
     CrossEntropyPython,
+    CrossEntropyLossFlash,
 )
 from internlm.utils.logger import get_logger
 
@@ -86,17 +87,8 @@ def new_cross_entropy(
 
     assert gpc.get_group(ParallelMode.TENSOR) is not None, "The process group should not be None."
 
-    try:
-        from flash_attn.losses.cross_entropy import (
-            CrossEntropyLoss as FlashCrossEntropyLoss,
-        )
-
-        flash_cross_entropy_impl = True
-    except (ModuleNotFoundError, ImportError):
-        flash_cross_entropy_impl = False
-
     assert (
-        gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl
+        gpc.config.model.get("use_flash_attn", False)
     ), "Only flash cross entropy support parallel_output"
 
     assert (
@@ -108,7 +100,7 @@
             which may result loss divergency in long sequence."
         )
 
-    return FlashCrossEntropyLoss(
+    return CrossEntropyLossFlash(
         ignore_index=ignore_index,
         reduction=reduction,
         label_smoothing=label_smoothing,
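
Note: the inline `flash_attn` import guard is replaced by the new `CrossEntropyLossFlash` export from `cross_entropy_ops`. The `flash_loss` module itself is among the 10 changed files but is not shown on this page; a plausible, clearly hypothetical sketch of such a wrapper, reusing the guard that was deleted above:

# Hypothetical sketch of cross_entropy_ops/flash_loss.py (the real file is part of
# this commit but not displayed here).
try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss as _FlashCrossEntropyLoss

    flash_cross_entropy_impl = True
except (ModuleNotFoundError, ImportError):
    _FlashCrossEntropyLoss = object
    flash_cross_entropy_impl = False


class CrossEntropyLossFlash(_FlashCrossEntropyLoss):
    """Stable, importable name for the flash_attn cross-entropy kernel."""

    def __init__(self, *args, **kwargs):
        # Fail loudly if flash_attn is unavailable instead of at first forward call.
        assert flash_cross_entropy_impl, "flash_attn is required for CrossEntropyLossFlash"
        super().__init__(*args, **kwargs)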

internlm/model/model_ops/ops/cross_entropy_ops/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,10 +2,12 @@
 from .py_naive_loss import CrossEntropyPython
 from .py_vocab_parallel_loss import CrossEntropyApexVocabParallel
 from .sequence_parallel_loss import VocabSequenceParallelCrossEntropyLoss
+from .flash_loss import CrossEntropyLossFlash
 
 
 __all__ = [
     "CrossEntropyLossApex",
     "CrossEntropyPython",
     "CrossEntropyApexVocabParallel",
     "VocabSequenceParallelCrossEntropyLoss",
+    "CrossEntropyLossFlash",
 ]
