Commit 1b387ca

[shardformer] refactor pipeline grad ckpt config (#5646)
* [shardformer] refactor pipeline grad ckpt config
* [shardformer] refactor pipeline grad ckpt config
* [pipeline] fix stage manager
1 parent 7ef9160 commit 1b387ca

11 files changed: 59 additions, 102 deletions


colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 2 additions & 0 deletions

@@ -983,6 +983,7 @@ def __init__(
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
         make_vocab_size_divisible_by: int = 64,
@@ -1056,6 +1057,7 @@ def __init__(
             pipeline_axis=self.pp_axis,
             enable_interleave=pp_style == "interleaved",
             num_model_chunks=num_model_chunks,
+            num_layers_per_stage=num_layers_per_stage,
         )

         if pp_style == "interleaved":
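For context, a minimal usage sketch (not part of the diff): with this change, an uneven layer split is requested directly on the plugin via the new num_layers_per_stage argument, while per-stage checkpoint counts stay in the gradient checkpoint config (see grad_ckpt_config.py below). The tp_size/pp_size values here are illustrative.

from colossalai.booster.plugin import HybridParallelPlugin

# Hypothetical 4-stage pipeline over an 80-layer model: stages get 19/20/20/21 layers.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=4,
    num_layers_per_stage=[19, 20, 20, 21],  # new argument introduced by this commit
)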

colossalai/pipeline/stage_manager.py

Lines changed: 29 additions & 53 deletions

@@ -27,16 +27,18 @@ def __init__(
         pipeline_axis: int,
         enable_interleave: bool = False,
         num_model_chunks: int = 1,
+        num_layers_per_stage: Optional[List[int]] = None,
     ) -> None:
         assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"

-        self.num_layers_per_stage = None
-
         self.pg_mesh = pg_mesh
         self.pipeline_axis = pipeline_axis
         self.prev_rank: Optional[Tuple[int, ...]] = None
         self.next_rank: Optional[Tuple[int, ...]] = None
         self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
+        if num_layers_per_stage is not None:
+            assert len(num_layers_per_stage) == self.num_stages
+        self.num_layers_per_stage = num_layers_per_stage

         # init prev and next coord
         coord = self.pg_mesh.coordinate()
@@ -56,6 +58,8 @@ def __init__(
             self.p2p_groups[tuple(ranks_in_group)] = group

         self.is_interleave = enable_interleave
+        # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
+        self.num_model_chunks: int = num_model_chunks
         if enable_interleave:
             # use circle p2p communication
             # add the process group of the first rank and the last rank
@@ -64,59 +68,11 @@ def __init__(
             ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
             self.p2p_groups[tuple(ranks_in_group)] = group

-        # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
-        self.num_model_chunks: int = num_model_chunks
-
         # for shardformer, hold stage indices of model
         self.stage_indices: List[Tuple[int, int]]
         # for shardformer, hold model chunk id
         self.model_chunk_id: Optional[int] = None

-    @property
-    def control_distribute_layers(self) -> bool:
-        return self.num_layers_per_stage is not None
-
-    def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
-        """Set the distribution configuration.
-        This allows user to customize the number of layers for each stage.
-
-        Args:
-            num_model_layers (int): Number of layers in the model.
-            num_layers_per_stage (List[int]): Number of layers for each stage.
-        """
-        assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
-        assert sum(num_layers_per_stage) == num_model_layers
-        assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
-        self.num_model_layers = num_model_layers
-        self.num_layers_per_stage = num_layers_per_stage
-
-    def distribute_layers(
-        self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
-    ) -> List[int]:
-        """Divide layers into stages"""
-        num_stages = self.num_stages if num_stages is None else num_stages
-        num_model_chunks = (
-            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
-        )
-
-        if self.control_distribute_layers:
-            assert num_layers == self.num_model_layers
-            return self.num_layers_per_stage
-
-        else:
-            quotient = num_layers // (num_stages * num_model_chunks)
-            remainder = num_layers % (num_stages * num_model_chunks)
-
-            # calculate the num_layers per stage
-            layers_per_stage = [quotient] * num_stages * num_model_chunks
-
-            # deal with the rest layers
-            if remainder > 0:
-                start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
-                for i in range(start_position, start_position + remainder):
-                    layers_per_stage[i] += 1
-            return layers_per_stage
-
     def get_stage_index(
         self,
         layers_per_stage: List[int],
@@ -139,9 +95,7 @@ def get_stage_index(

         """
         stage = self.stage if stage is None else stage
-        num_model_chunks = (
-            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
-        )
+        num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
         num_stages = self.num_stages if num_stages is None else num_stages

         num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
@@ -261,3 +215,25 @@ def switch_model_chunk_id(self, model_chunk_id: int):
         self.model_chunk_id = model_chunk_id
         yield
         self.model_chunk_id = old_model_chunk_id
+
+    def distribute_layers(
+        self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
+    ) -> List[int]:
+        if self.num_layers_per_stage is not None:
+            assert sum(self.num_layers_per_stage) == num_layers
+            return self.num_layers_per_stage
+
+        num_stages = self.num_stages if num_stages is None else num_stages
+        num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
+        quotient = num_layers // (num_stages * num_model_chunks)
+        remainder = num_layers % (num_stages * num_model_chunks)
+
+        # calculate the num_layers per stage
+        layers_per_stage = [quotient] * num_stages * num_model_chunks
+
+        # deal with the rest layers
+        if remainder > 0:
+            start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
+            for i in range(start_position, start_position + remainder):
+                layers_per_stage[i] += 1
+        return layers_per_stage
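As a standalone illustration (not part of the patch), the fallback even-split arithmetic used when num_layers_per_stage is not set behaves as follows; the helper below only mirrors the logic above outside the class:

def split_evenly(num_layers: int, num_stages: int, num_model_chunks: int = 1) -> list:
    # Mirrors the default path: an even quotient per slot, with the remainder spread over the middle slots.
    slots = num_stages * num_model_chunks
    quotient, remainder = divmod(num_layers, slots)
    layers_per_stage = [quotient] * slots
    start_position = slots // 2 - remainder // 2
    for i in range(start_position, start_position + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage

# e.g. split_evenly(80, 4) -> [20, 20, 20, 20]; split_evenly(30, 4) -> [7, 8, 8, 7]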

colossalai/shardformer/modeling/llama.py

Lines changed: 2 additions & 0 deletions

@@ -168,8 +168,10 @@ def llama_model_forward(
         if shard_config.gradient_checkpoint_config is not None:
             num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
                 stage=stage_manager.stage,
+                num_stages=stage_manager.num_stages,
                 num_layers=end_idx - start_idx,
                 model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
+                num_model_chunks=stage_manager.num_model_chunks,
             )
             assert num_ckpt_layers <= end_idx - start_idx

colossalai/shardformer/modeling/mistral.py

Lines changed: 2 additions & 0 deletions

@@ -129,8 +129,10 @@ def mistral_model_forward(
         if shard_config.gradient_checkpoint_config is not None:
             num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
                 stage=stage_manager.stage,
+                num_stages=stage_manager.num_stages,
                 num_layers=end_idx - start_idx,
                 model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
+                num_model_chunks=stage_manager.num_model_chunks,
             )
             assert num_ckpt_layers <= end_idx - start_idx

colossalai/shardformer/policies/base_policy.py

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,7 @@ class SubModuleReplacementDescription:
         kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
         ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
     """
+
     suffix: str
     target_module: Union[ParallelModule, BaseLayerNorm]
     kwargs: Dict[str, Any] = None
@@ -54,6 +55,7 @@ def example_replace_weight(module: torch.nn.Module):
             object which specifies the module to be replaced and the target module used to replacement.
         method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
     """
+
     attribute_replacement: Dict[str, Any] = None
     param_replacement: List[Callable] = None
     sub_module_replacement: List[SubModuleReplacementDescription] = None

colossalai/shardformer/shard/grad_ckpt_config.py

Lines changed: 9 additions & 22 deletions

@@ -47,46 +47,33 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
         ...

     """
-    num_stages: Optional[int] = None
-    num_model_chunks: Optional[int] = None
-    num_model_layers: Optional[int] = None
-    num_layers_per_stage: Optional[List[int]] = None
     num_ckpt_layers_per_stage: Optional[List[int]] = None

     def __post_init__(self):
-        if self._enable_gradient_checkpointing_ratio:
+        if self._enable_customized_ckpt_layers_per_stage:
+            assert all([num_ckpt_layers >= 0 for num_ckpt_layers in self.num_ckpt_layers_per_stage])
+        elif self._enable_gradient_checkpointing_ratio:
             if not (0 <= self.gradient_checkpointing_ratio <= 1):
                 raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%")

-        if self._enable_customized_ckpt_layers_per_stage:
-            assert (
-                self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
-            )
-            assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
-            assert all(
-                [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage]
-            )
-            self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers
-
     @property
     def _enable_gradient_checkpointing_ratio(self) -> bool:
         return self.gradient_checkpointing_ratio is not None

-    @property
-    def _customize_num_layers_per_stage(self) -> bool:
-        return self.num_layers_per_stage is not None and self.num_model_layers is not None
-
     @property
     def _enable_customized_ckpt_layers_per_stage(self) -> bool:
         return self.num_ckpt_layers_per_stage is not None

-    def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
+    def get_num_ckpt_layers(
+        self, stage: int, num_stages: int, num_layers: int, model_chunk_id: int = 0, num_model_chunks: int = 1
+    ) -> int:
         if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
             raise RuntimeError("No checkpointed layers information is provided")

         if self._enable_customized_ckpt_layers_per_stage:
-            assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks
-            num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
+            assert len(self.num_ckpt_layers_per_stage) == num_stages * num_model_chunks
+            assert stage <= num_stages and model_chunk_id <= num_model_chunks
+            num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * num_stages]
             assert num_ckpt_layers <= num_layers
             return num_ckpt_layers
         else:
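A small sketch of the new call shape, mirroring the updated call sites in modeling/llama.py and modeling/mistral.py; the concrete numbers are illustrative only:

from colossalai.shardformer.shard.grad_ckpt_config import PipelineGradientCheckpointConfig

cfg = PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[19, 19, 19, 13])

# Stage 1 of a 4-stage, single-chunk pipeline that holds 20 layers on that stage:
num_ckpt = cfg.get_num_ckpt_layers(stage=1, num_stages=4, num_layers=20, model_chunk_id=0, num_model_chunks=1)
assert num_ckpt == 19  # num_ckpt_layers_per_stage[1 + 0 * 4]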

colossalai/shardformer/shard/shard_config.py

Lines changed: 1 addition & 11 deletions

@@ -7,7 +7,7 @@

 from colossalai.pipeline.stage_manager import PipelineStageManager

-from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
+from .grad_ckpt_config import GradientCheckpointConfig

 __all__ = ["ShardConfig"]
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -105,16 +105,6 @@ def __post_init__(self):
         else:
             self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)

-        if (
-            self.pipeline_stage_manager is not None
-            and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig)
-            and self.gradient_checkpoint_config._customize_num_layers_per_stage
-        ):
-            self.pipeline_stage_manager.set_distribution_config(
-                self.gradient_checkpoint_config.num_model_layers,
-                self.gradient_checkpoint_config.num_layers_per_stage,
-            )
-
     def _turn_on_all_optimization(self):
         """
         Turn on all optimization.

examples/language/llama/benchmark.py

Lines changed: 9 additions & 10 deletions

@@ -88,16 +88,15 @@ def empty_init():
         pass

     # ckpt config for LLaMA3-70B on 64 H100 GPUs
-    ckpt_config = (
-        PipelineGradientCheckpointConfig(
-            num_stages=args.pp,
-            num_model_chunks=1,
-            num_model_layers=80,
-            num_layers_per_stage=[19, 20, 20, 21],
-            num_ckpt_layers_per_stage=[19, 19, 19, 13],
-        )
+    hybrid_kwargs = (
+        {
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_ckpt_layers_per_stage=[19, 19, 19, 13],
+            ),
+            "num_layers_per_stage": [19, 20, 20, 21],
+        }
         if args.custom_ckpt
-        else None
+        else {}
     )

     # ==============================
@@ -173,7 +172,7 @@ def empty_init():
             microbatch_size=args.mbs,
             precision="bf16",
             dp_outside=False,
-            gradient_checkpoint_config=ckpt_config,
+            **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
         plugin = HybridParallelPlugin(

tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
     def __init__(self):
         self.is_interleave = False
         self.num_layers_per_stage = None
+        self.num_model_chunks = 1

     @property
     def num_stages(self):

tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
     def __init__(self):
         self.is_interleave = False
         self.num_layers_per_stage = None
+        self.num_model_chunks = 1

     @property
     def num_stages(self):
