Commit 17062c8

[hotfix] fix hybrid checkpointio for sp+dp (#6184)
* Update hybrid_parallel_plugin.py
* Update hybrid_parallel_plugin.py
* Update hybrid_parallel_plugin.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update build_on_pr.yml
* Update test_zerobubble_pp.py
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ca0aa23 commit 17062c8

File tree

.github/workflows/build_on_pr.yml
colossalai/booster/plugin/hybrid_parallel_plugin.py
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
tests/test_pipeline/test_schedule/test_zerobubble_pp.py
tests/test_shardformer/test_model/test_shard_deepseek.py
tests/test_shardformer/test_model/test_shard_mixtral.py

6 files changed: +35 −30 lines changed


.github/workflows/build_on_pr.yml

Lines changed: 1 addition & 1 deletion
@@ -199,7 +199,7 @@ jobs:
           fi

       - name: Upload test coverage artifact
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: report
           path: report/

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 14 additions & 11 deletions
@@ -1188,6 +1188,15 @@ def __init__(
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

+        # sync gradients across DP * SP ranks
+        # sync gradients across DP * SP ranks
+        # Apply Hybrid ZeRO across DP * SP ranks
+        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            self.dp_size = get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
@@ -1298,19 +1307,11 @@ def configure(
         use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
             self.dp_size == 1 and self.pp_size == 1
         )
-        # sync gradients across DP * SP ranks
-        # sync gradients across DP * SP ranks
-        # Apply Hybrid ZeRO across DP * SP ranks
-        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
-            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            self.dp_size = get_world_size(dp_group)
-        else:
-            dp_group = self.dp_group
         model = HybridParallelModule(
             model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=dp_group,
+            dp_group=self.mixed_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
@@ -1359,7 +1360,7 @@ def configure(
                 model,
                 use_pipeline=self.enable_pipeline_parallelism,
                 param_info=param_info,
-                dp_process_group=dp_group,
+                dp_process_group=self.mixed_dp_group,
                 tp_process_group=self.tp_group,
                 pp_process_group=self.pp_group,
                 verbose=True,
@@ -1488,7 +1489,9 @@ def seed_worker(worker_id):
         )

     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(
+            self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
+        )

     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (
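
The hunks above are the heart of the fix: when sequence parallelism is enabled and does not share ranks with tensor parallelism, gradients (and, per the "Apply Hybrid ZeRO across DP * SP ranks" comment, ZeRO sharding) span the combined DP x SP ranks, so that merged group is now created once in __init__ as self.mixed_dp_group and passed to HybridParallelCheckpointIO as well; previously the checkpoint IO still received the plain self.dp_group. The sketch below illustrates the merged-group idea with plain torch.distributed rather than the plugin's ProcessGroupMesh; the (dp, sp, tp) = (2, 2, 2) layout and the helper names build_mixed_dp_group and sync_grads are illustrative assumptions, not ColossalAI APIs.

```python
# Minimal sketch, assuming a (dp, sp, tp) = (2, 2, 2) layout on 8 ranks and
# plain torch.distributed instead of ColossalAI's ProcessGroupMesh.
# The helper names below are hypothetical.
import torch
import torch.distributed as dist


def build_mixed_dp_group(dp_size: int, sp_size: int, tp_size: int):
    """Return the merged DP x SP group that contains the current rank."""
    mesh = torch.arange(dp_size * sp_size * tp_size).reshape(dp_size, sp_size, tp_size)
    rank = dist.get_rank()
    mixed_group = None
    # new_group is collective: every rank must enter it for every group,
    # but each rank only keeps the group it belongs to.
    for tp_rank in range(tp_size):
        ranks = mesh[:, :, tp_rank].reshape(-1).tolist()
        group = dist.new_group(ranks)
        if rank in ranks:
            mixed_group = group
    return mixed_group


def sync_grads(model: torch.nn.Module, mixed_dp_group) -> None:
    """Average gradients over the merged DP x SP ranks (no ZeRO)."""
    world_size = dist.get_world_size(mixed_dp_group)
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, group=mixed_dp_group)
            param.grad.div_(world_size)
```

Because optimizer state and gradients live on this merged group, the checkpoint IO has to be constructed with the same group to save and load shards consistently, which is exactly what the get_checkpoint_io hunk changes.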

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

Lines changed: 12 additions & 10 deletions
@@ -351,6 +351,14 @@ def __init__(
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
+
+        # sync gradients across DP * SP ranks
+        if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
+            self.dp_size = dist.get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.use_fp8 = use_fp8

         self.shard_config = ShardConfig(
@@ -404,7 +412,7 @@ def __init__(

     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
-            self.dp_group,
+            self.mixed_dp_group,
             self.pp_group,
             self.tp_group,
             self.sp_group,
@@ -435,20 +443,14 @@ def configure(
             and self.sequence_parallelism_mode == "all_to_all"
         )

-        # sync gradients across DP * SP ranks
-        if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-            dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
-        else:
-            dp_group = self.dp_group
-
         if use_ddp:
             self.logger.warning(
                 f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
                 ranks=[0],
             )
             self.ddp_config["find_unused_parameters"] = True

-        if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+        if dist.get_process_group_ranks(self.mixed_dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
             raise ValueError(
                 f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
             )
@@ -457,7 +459,7 @@ def configure(
             module=model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=dp_group,
+            dp_group=self.mixed_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
@@ -507,7 +509,7 @@ def configure(
                 model,
                 use_pipeline=self.enable_pipeline_parallelism,
                 param_info=param_info,
-                dp_process_group=dp_group,
+                dp_process_group=self.mixed_dp_group,
                 tp_process_group=self.tp_group,
                 pp_process_group=self.pp_group,
                 moe_dp_group=self.moe_dp_group,
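
For the MoE plugin the merged group additionally spans the expert-parallel axis (moe_dp x ep x sp), and it is likewise created once in __init__ so that get_checkpoint_io can hand it to MoECheckpointIO. The guard kept in configure still applies: PyTorch DDP reduces gradients over exactly one process group, so when DDP is used the merged group must contain the same ranks as the MoE DP group. Below is a hedged sketch of that guard; wrap_with_ddp_if_safe is a hypothetical helper, and only the rank-equality check itself mirrors the diff above.

```python
# Hedged sketch of the DDP rank-equality guard; the helper name and wiring
# are assumptions, only the check mirrors the hunk above.
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_with_ddp_if_safe(model, mixed_dp_group, moe_dp_group):
    mixed_ranks = dist.get_process_group_ranks(mixed_dp_group)
    moe_ranks = dist.get_process_group_ranks(moe_dp_group)
    if mixed_ranks != moe_ranks:
        # DDP can only all-reduce over a single group, so a mismatch between
        # the merged DP group and the MoE DP group cannot be handled by DDP.
        raise ValueError(
            f"dp_group {mixed_ranks} and moe_dp_group {moe_ranks} differ; "
            "adjust the parallel config or bypass DDP."
        )
    # Not every expert receives tokens each step, hence find_unused_parameters.
    return DDP(model, process_group=mixed_dp_group, find_unused_parameters=True)
```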

tests/test_pipeline/test_schedule/test_zerobubble_pp.py

Lines changed: 4 additions & 4 deletions
@@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
@@ -1040,12 +1040,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
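
The test updates here, and in the deepseek/mixtral shardformer tests below, follow one pattern: each rank's scalar output from the parallel model is summed across the merged DP x SP group, every rank's input is gathered over that same group, and the reference model is run on all gathered inputs so the two sums can be compared. A condensed sketch of that pattern follows; compare_parallel_vs_reference is a hypothetical helper, and the reference model is assumed to expose a HF-style inputs_embeds interface as in the tests.

```python
# Condensed, assumption-laden sketch of the comparison pattern used in these
# tests; the helper name and tensor shapes are illustrative.
import torch
import torch.distributed as dist


def compare_parallel_vs_reference(parallel_output, ref_model, local_input, mixed_dp_group, atol=1e-5):
    # Sum the per-rank parallel outputs over the merged DP x SP group.
    parallel_sum = parallel_output.clone()
    dist.all_reduce(parallel_sum, group=mixed_dp_group)

    # Gather every rank's (different) input and feed each one to the reference model.
    world_size = dist.get_world_size(mixed_dp_group)
    all_inputs = [torch.empty_like(local_input) for _ in range(world_size)]
    dist.all_gather(all_inputs, local_input, group=mixed_dp_group)

    ref_sum = torch.zeros_like(parallel_sum)
    for inp in all_inputs:
        ref_sum = ref_sum + ref_model(inputs_embeds=inp).last_hidden_state.mean()

    assert torch.allclose(parallel_sum, ref_sum, atol=atol)
```

Using the merged group matters here because plugin.dp_group alone no longer covers every rank that holds distinct data once sequence parallelism is combined with data parallelism.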

tests/test_shardformer/test_model/test_shard_deepseek.py

Lines changed: 2 additions & 2 deletions
@@ -125,12 +125,12 @@ def run_deepseek_commom(parallel_config: Tuple[int, ...]):
         parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()

tests/test_shardformer/test_model/test_shard_mixtral.py

Lines changed: 2 additions & 2 deletions
@@ -118,12 +118,12 @@ def run_mixtral_commom(config: Tuple[int, ...]):
         parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
