
Commit e094933

[shardformer] fix pipeline grad ckpt (#5620)
* [shardformer] fix pipeline grad ckpt
1 parent: d83c633 · commit: e094933

File tree
  colossalai/shardformer/shard/grad_ckpt_config.py
  colossalai/shardformer/shard/shard_config.py

2 files changed: 18 additions (+), 1 deletion (-)

colossalai/shardformer/shard/grad_ckpt_config.py
Lines changed: 6 additions & 0 deletions

@@ -22,6 +22,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
         2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`.
 
     """
+
     """
     Args:
         gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None.
@@ -49,6 +50,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
     num_stages: Optional[int] = None
     num_model_chunks: Optional[int] = None
     num_model_layers: Optional[int] = None
+    num_layers_per_stage: Optional[List[int]] = None
     num_ckpt_layers_per_stage: Optional[List[int]] = None
 
     def __post_init__(self):
@@ -70,6 +72,10 @@ def __post_init__(self):
     def _enable_gradient_checkpointing_ratio(self) -> bool:
         return self.gradient_checkpointing_ratio is not None
 
+    @property
+    def _customize_num_layers_per_stage(self) -> bool:
+        return self.num_layers_per_stage is not None and self.num_model_layers is not None
+
     @property
     def _enable_customized_ckpt_layers_per_stage(self) -> bool:
         return self.num_ckpt_layers_per_stage is not None
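Taken together, the new `num_layers_per_stage` field and `_customize_num_layers_per_stage` property let a user pin how many layers each pipeline stage holds, in addition to how many of them are checkpointed. Below is a minimal sketch of building such a config, assuming an installed ColossalAI that includes this commit; the concrete stage and layer counts are invented for illustration, and the sanity checks `__post_init__` runs on them are not part of this diff.

# Minimal sketch (not from the commit): build the config with the new
# num_layers_per_stage field. The numbers below (4 stages, 32 layers) are
# invented for illustration.
from colossalai.shardformer.shard.grad_ckpt_config import PipelineGradientCheckpointConfig

ckpt_config = PipelineGradientCheckpointConfig(
    num_stages=4,
    num_model_chunks=1,
    num_model_layers=32,
    num_layers_per_stage=[10, 10, 6, 6],     # new field: how many layers each stage holds
    num_ckpt_layers_per_stage=[8, 8, 0, 0],  # how many of those layers are checkpointed
)

# New property from this commit: True only when both num_layers_per_stage and
# num_model_layers are given.
assert ckpt_config._customize_num_layers_per_stage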

colossalai/shardformer/shard/shard_config.py
Lines changed: 12 additions & 1 deletion

@@ -7,7 +7,7 @@
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from .grad_ckpt_config import GradientCheckpointConfig
+from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
 
 __all__ = ["ShardConfig"]
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -30,6 +30,7 @@ class ShardConfig:
     gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
     enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
     """
+
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     sequence_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -104,6 +105,16 @@ def __post_init__(self):
         else:
             self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)
 
+        if (
+            self.pipeline_stage_manager is not None
+            and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig)
+            and self.gradient_checkpoint_config._customize_num_layers_per_stage
+        ):
+            self.pipeline_stage_manager.set_distribution_config(
+                self.gradient_checkpoint_config.num_model_layers,
+                self.gradient_checkpoint_config.num_layers_per_stage,
+            )
+
     def _turn_on_all_optimization(self):
         """
         Turn on all optimization.
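The `ShardConfig` change is the glue: when a pipeline stage manager is present and the gradient checkpoint config customizes `num_layers_per_stage`, `__post_init__` forwards the layer distribution to `PipelineStageManager.set_distribution_config`. The sketch below mirrors that gating with stub classes so it runs without a distributed setup; the stubs and the `wire_distribution` helper are illustrative stand-ins, not ColossalAI code.

# Standalone sketch of the gating added to ShardConfig.__post_init__. The stub
# classes are hypothetical stand-ins for PipelineStageManager and
# PipelineGradientCheckpointConfig so the logic can run without torch.distributed.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StubStageManager:
    num_model_layers: Optional[int] = None
    num_layers_per_stage: Optional[List[int]] = None

    def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
        # The real PipelineStageManager uses these values to fix how layers are
        # split across stages; the stub just records them.
        self.num_model_layers = num_model_layers
        self.num_layers_per_stage = num_layers_per_stage


@dataclass
class StubGradCkptConfig:
    num_model_layers: Optional[int] = None
    num_layers_per_stage: Optional[List[int]] = None

    @property
    def _customize_num_layers_per_stage(self) -> bool:
        return self.num_layers_per_stage is not None and self.num_model_layers is not None


def wire_distribution(stage_manager, grad_ckpt_config) -> None:
    # Same condition as the added block: only forward the distribution when a
    # pipeline is configured and the user customized num_layers_per_stage.
    if (
        stage_manager is not None
        and isinstance(grad_ckpt_config, StubGradCkptConfig)
        and grad_ckpt_config._customize_num_layers_per_stage
    ):
        stage_manager.set_distribution_config(
            grad_ckpt_config.num_model_layers,
            grad_ckpt_config.num_layers_per_stage,
        )


manager = StubStageManager()
wire_distribution(manager, StubGradCkptConfig(num_model_layers=32, num_layers_per_stage=[10, 10, 6, 6]))
print(manager.num_layers_per_stage)  # [10, 10, 6, 6]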
