@@ -27,16 +27,18 @@ def __init__(
27
27
pipeline_axis : int ,
28
28
enable_interleave : bool = False ,
29
29
num_model_chunks : int = 1 ,
30
+ num_layers_per_stage : Optional [List [int ]] = None ,
30
31
) -> None :
31
32
assert enable_interleave or num_model_chunks == 1 , "num_model_chunks must be 1 when enable_interleave is False"
32
33
33
- self .num_layers_per_stage = None
34
-
35
34
self .pg_mesh = pg_mesh
36
35
self .pipeline_axis = pipeline_axis
37
36
self .prev_rank : Optional [Tuple [int , ...]] = None
38
37
self .next_rank : Optional [Tuple [int , ...]] = None
39
38
self .p2p_groups : Dict [Tuple [int , int ], ProcessGroup ] = {}
39
+ if num_layers_per_stage is not None :
40
+ assert len (num_layers_per_stage ) == self .num_stages
41
+ self .num_layers_per_stage = num_layers_per_stage
40
42
41
43
# init prev and next coord
42
44
coord = self .pg_mesh .coordinate ()
@@ -56,6 +58,8 @@ def __init__(
56
58
self .p2p_groups [tuple (ranks_in_group )] = group
57
59
58
60
self .is_interleave = enable_interleave
61
+ # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
62
+ self .num_model_chunks : int = num_model_chunks
59
63
if enable_interleave :
60
64
# use circle p2p communication
61
65
# add the process group of the first rank and the last rank
@@ -64,59 +68,11 @@ def __init__(
64
68
ranks_in_group = self .pg_mesh .get_ranks_in_group (group )
65
69
self .p2p_groups [tuple (ranks_in_group )] = group
66
70
67
- # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
68
- self .num_model_chunks : int = num_model_chunks
69
-
70
71
# for shardformer, hold stage indices of model
71
72
self .stage_indices : List [Tuple [int , int ]]
72
73
# for shardformer, hold model chunk id
73
74
self .model_chunk_id : Optional [int ] = None
74
75
75
- @property
76
- def control_distribute_layers (self ) -> bool :
77
- return self .num_layers_per_stage is not None
78
-
79
- def set_distribution_config (self , num_model_layers : int , num_layers_per_stage : List [int ]) -> None :
80
- """Set the distribution configuration.
81
- This allows user to customize the number of layers for each stage.
82
-
83
- Args:
84
- num_model_layers (int): Number of layers in the model.
85
- num_layers_per_stage (List[int]): Number of layers for each stage.
86
- """
87
- assert all ([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage ])
88
- assert sum (num_layers_per_stage ) == num_model_layers
89
- assert len (num_layers_per_stage ) == self .num_stages * (self .num_model_chunks if self .is_interleave else 1 )
90
- self .num_model_layers = num_model_layers
91
- self .num_layers_per_stage = num_layers_per_stage
92
-
93
- def distribute_layers (
94
- self , num_layers : int , num_stages : Optional [int ] = None , num_model_chunks : Optional [int ] = None
95
- ) -> List [int ]:
96
- """Divide layers into stages"""
97
- num_stages = self .num_stages if num_stages is None else num_stages
98
- num_model_chunks = (
99
- (self .num_model_chunks if self .is_interleave else 1 ) if num_model_chunks is None else num_model_chunks
100
- )
101
-
102
- if self .control_distribute_layers :
103
- assert num_layers == self .num_model_layers
104
- return self .num_layers_per_stage
105
-
106
- else :
107
- quotient = num_layers // (num_stages * num_model_chunks )
108
- remainder = num_layers % (num_stages * num_model_chunks )
109
-
110
- # calculate the num_layers per stage
111
- layers_per_stage = [quotient ] * num_stages * num_model_chunks
112
-
113
- # deal with the rest layers
114
- if remainder > 0 :
115
- start_position = (num_stages * num_model_chunks ) // 2 - remainder // 2
116
- for i in range (start_position , start_position + remainder ):
117
- layers_per_stage [i ] += 1
118
- return layers_per_stage
119
-
120
76
def get_stage_index (
121
77
self ,
122
78
layers_per_stage : List [int ],
@@ -139,9 +95,7 @@ def get_stage_index(
139
95
140
96
"""
141
97
stage = self .stage if stage is None else stage
142
- num_model_chunks = (
143
- (self .num_model_chunks if self .is_interleave else 1 ) if num_model_chunks is None else num_model_chunks
144
- )
98
+ num_model_chunks = self .num_model_chunks if num_model_chunks is None else num_model_chunks
145
99
num_stages = self .num_stages if num_stages is None else num_stages
146
100
147
101
num_layers_per_stage_accumulated = np .insert (np .cumsum (layers_per_stage ), 0 , 0 )
@@ -261,3 +215,25 @@ def switch_model_chunk_id(self, model_chunk_id: int):
261
215
self .model_chunk_id = model_chunk_id
262
216
yield
263
217
self .model_chunk_id = old_model_chunk_id
218
+
219
+ def distribute_layers (
220
+ self , num_layers : int , num_stages : Optional [int ] = None , num_model_chunks : Optional [int ] = None
221
+ ) -> List [int ]:
222
+ if self .num_layers_per_stage is not None :
223
+ assert sum (self .num_layers_per_stage ) == num_layers
224
+ return self .num_layers_per_stage
225
+
226
+ num_stages = self .num_stages if num_stages is None else num_stages
227
+ num_model_chunks = self .num_model_chunks if num_model_chunks is None else num_model_chunks
228
+ quotient = num_layers // (num_stages * num_model_chunks )
229
+ remainder = num_layers % (num_stages * num_model_chunks )
230
+
231
+ # calculate the num_layers per stage
232
+ layers_per_stage = [quotient ] * num_stages * num_model_chunks
233
+
234
+ # deal with the rest layers
235
+ if remainder > 0 :
236
+ start_position = (num_stages * num_model_chunks ) // 2 - remainder // 2
237
+ for i in range (start_position , start_position + remainder ):
238
+ layers_per_stage [i ] += 1
239
+ return layers_per_stage
0 commit comments