1212# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
15+ from enum import IntEnum
1516from typing import List
1617
1718import torch
1819
1920
class CpType(IntEnum):
    """Enumerates the CP types that a ``cp_config["cp_type"]`` entry may name.

    Members are plain integers (``IntEnum``), so they compare equal to their
    numeric values.
    """

    ULYSSES = 0  # CP type for ulysses parallelism
    STAR = 1  # CP type for star attention
    RING = 2  # CP type for ring attention
    HELIX = 3  # CP type for helix parallelism
30+
31+
2032class Mapping (object ):
2133 '''
2234 A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2
@@ -135,58 +147,70 @@ def __init__(
135147 if moe_cluster_size == - 1 :
136148 moe_cluster_size = 1
137149
150+ cp_type = CpType .ULYSSES if cp_config is None else cp_config .get (
151+ "cp_type" , CpType .ULYSSES )
152+ moe_world_size = tp_size if cp_type == CpType .ULYSSES else tp_size * cp_size
153+
138154 if moe_tp_size == - 1 and moe_ep_size == - 1 :
139- moe_tp_size = tp_size // moe_cluster_size
155+ moe_tp_size = moe_world_size // moe_cluster_size
140156 moe_ep_size = 1
141157
142158 elif moe_tp_size == - 1 :
143- moe_tp_size = tp_size // (moe_ep_size * moe_cluster_size )
159+ moe_tp_size = moe_world_size // (moe_ep_size * moe_cluster_size )
144160
145161 elif moe_ep_size == - 1 :
146- moe_ep_size = tp_size // (moe_tp_size * moe_cluster_size )
162+ moe_ep_size = moe_world_size // (moe_tp_size * moe_cluster_size )
147163
148164 if attn_tp_size == - 1 and attn_cp_size == - 1 :
149- # fallback to ulysses
150- attn_tp_size = tp_size * cp_size
151- attn_cp_size = 1
165+ if cp_type == CpType .ULYSSES :
166+ # fallback to ulysses
167+ attn_tp_size = tp_size * cp_size
168+ attn_cp_size = 1
169+ else :
170+ # fallback to helix
171+ attn_tp_size = tp_size
172+ attn_cp_size = cp_size
152173
153174 elif attn_tp_size == - 1 :
154- attn_tp_size = cp_size * tp_size // attn_cp_size
175+ attn_tp_size = ( tp_size * cp_size ) // attn_cp_size
155176
156177 elif attn_cp_size == - 1 :
157- attn_cp_size = cp_size * tp_size // attn_tp_size
178+ attn_cp_size = ( tp_size * cp_size ) // attn_tp_size
158179
159- if attn_cp_size != 1 :
180+ if attn_cp_size != 1 and cp_type == CpType . ULYSSES :
160181 raise ValueError (
161- f"attn_cp_size must be 1 for now, but got { attn_tp_size } , { attn_cp_size } ."
182+ f"attn_cp_size must be 1 for now for ulysses , but got { attn_tp_size } , { attn_cp_size } ."
162183 )
163184
164185 if auto_parallel :
165- if tp_size != 1 or pp_size != 1 or tp_size != 1 :
186+ if tp_size != 1 or pp_size != 1 or cp_size != 1 :
166187 raise ValueError (
167- f "When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got { tp_size } , { pp_size } , { cp_size } . "
168- )
188+ "When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, "
189+ f"but got { tp_size } , { pp_size } , { cp_size } ." )
169190 else :
170191 if tp_size * pp_size * cp_size != world_size :
171192 raise ValueError (
172- f"world_size must equal to tp_size * pp_size * cp_size, but got { world_size } != { tp_size } * { pp_size } * { cp_size } ."
193+ "world_size must equal to tp_size * pp_size * cp_size, "
194+ f"but got { world_size } != { tp_size } * { pp_size } * { cp_size } ."
173195 )
174196
175197 moe_tp_ep_size = moe_tp_size * moe_ep_size
176198 moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size
177- if moe_tp_cluster_ep_size != tp_size :
199+ if moe_tp_cluster_ep_size != moe_world_size :
178200 raise ValueError (
179- f"tp_size must equal to moe_tp_size * moe_ep_size * moe_cluster_size, but got { tp_size } != { moe_tp_size } * { moe_ep_size } * { moe_cluster_size } "
180- )
201+ " moe_tp_size * moe_ep_size * moe_cluster_size must equal to moe_world_size, "
202+ f"but got { moe_tp_cluster_ep_size } != { moe_world_size } " )
181203
182204 attn_tp_cp_size = attn_tp_size * attn_cp_size
183205 if attn_tp_cp_size != tp_size * cp_size :
184206 raise ValueError (
185- f"tp_size * cp_size must equal to attn_tp_size * attn_cp_size, but got { tp_size } * { cp_size } != { attn_tp_size } * { attn_cp_size } "
207+ "tp_size * cp_size must equal to attn_tp_size * attn_cp_size, "
208+ f"but got { tp_size } * { cp_size } != { attn_tp_size } * { attn_cp_size } "
186209 )
187210
188- if moe_ep_size != 1 and cp_size > 1 :
189- raise NotImplementedError ("CP don't support MoE tp/ep yet" )
211+ if moe_ep_size != 1 and cp_size > 1 and cp_type != CpType .HELIX :
212+ raise NotImplementedError (
213+ f"CP { cp_type } doesn't support MoE tp/ep yet" )
190214
191215 self .tp_size = tp_size
192216 self .cp_size = cp_size
@@ -275,6 +299,7 @@ def __eq__(self, other):
275299 and self .moe_ep_size == other .moe_ep_size
276300 and self .attn_tp_size == other .attn_tp_size
277301 and self .attn_cp_size == other .attn_cp_size
302+ and self .cp_config == other .cp_config
278303 and self .auto_parallel == other .auto_parallel )
279304
280305 def __hash__ (self ):
@@ -290,6 +315,8 @@ def __hash__(self):
290315 self .moe_ep_size ,
291316 self .attn_tp_size ,
292317 self .attn_cp_size ,
318+ # note: we do not allow updating cp_config after initialization
319+ tuple (sorted (self .cp_config .items ())),
293320 self .auto_parallel ,
294321 ))
295322
@@ -376,8 +403,13 @@ def local_rank(self):
def dp_size(self):
    """Return ``tp_size`` when ``enable_attention_dp`` is truthy, otherwise 1."""
    if self.enable_attention_dp:
        return self.tp_size
    return 1
378405
379- def has_cp (self ):
380- return self .cp_size > 1
def has_cp_ulysses(self):
    """Return True when ``cp_size > 1`` and the configured CP type is Ulysses.

    A ``cp_config`` without a ``"cp_type"`` entry counts as Ulysses, which is
    the same default ``__init__`` applies when it resolves the CP type.
    """
    # The default must mirror __init__: a bare .get() yields None for a missing
    # key and would report "not Ulysses" for a mapping validated as Ulysses.
    return self.cp_size > 1 and self.cp_config.get(
        "cp_type", CpType.ULYSSES) == CpType.ULYSSES
409+
def has_cp_helix(self):
    """Return True when ``cp_size > 1`` and ``cp_config["cp_type"]`` is HELIX."""
    if not self.cp_size > 1:
        return False
    configured_type = self.cp_config.get("cp_type")
    return configured_type == CpType.HELIX
381413
def get_node_rank(self, rank: int):
    """Return ``rank`` floor-divided by ``gpus_per_node``.

    Presumably the index of the node that hosts ``rank`` when ranks are laid
    out ``gpus_per_node`` at a time -- confirm against the launcher.
    """
    node_index = rank // self.gpus_per_node
    return node_index
@@ -415,6 +447,29 @@ def next_pp_rank(self):
415447 p = p - self .world_size
416448 return p
417449
def is_last_cp_rank(self):
    """Return True when ``cp_rank`` equals the highest CP index, ``cp_size - 1``."""
    last_index = self.cp_size - 1
    return self.cp_rank == last_index
452+
def is_first_cp_rank(self):
    """Return True when ``cp_rank`` is 0."""
    current_cp_rank = self.cp_rank
    return current_cp_rank == 0
455+
def has_cp(self):
    """Return True when ``cp_size`` is greater than 1, whatever the CP type."""
    cp_enabled = self.cp_size > 1
    return cp_enabled
458+
def prev_cp_rank(self):
    """Return ``rank - tp_size``, wrapped to stay in the current rank block.

    Ranks are grouped into consecutive blocks of ``tp_size * cp_size``. If
    stepping back by ``tp_size`` lands in an earlier block, the result is
    shifted forward by one block size so it stays in the block of ``rank``.
    """
    block_size = self.tp_size * self.cp_size
    candidate = self.rank - self.tp_size
    # Floor division gives the block index; a smaller index means we underflowed.
    left_block = candidate // block_size < self.rank // block_size
    return candidate + block_size if left_block else candidate
465+
def next_cp_rank(self):
    """Return ``rank + tp_size``, wrapped to stay in the current rank block.

    Ranks are grouped into consecutive blocks of ``tp_size * cp_size``. If
    stepping forward by ``tp_size`` lands in a later block, the result is
    shifted back by one block size so it stays in the block of ``rank``.
    """
    block_size = self.tp_size * self.cp_size
    candidate = self.rank + self.tp_size
    # Floor division gives the block index; a larger index means we overflowed.
    left_block = candidate // block_size > self.rank // block_size
    return candidate - block_size if left_block else candidate
472+
def has_moe_cluster(self):
    """Return True when ``moe_cluster_size`` is greater than 1."""
    cluster_enabled = self.moe_cluster_size > 1
    return cluster_enabled
420475
@@ -453,5 +508,6 @@ def to_dict(self):
453508 'moe_ep_size' : self .moe_ep_size ,
454509 'attn_tp_size' : self .attn_tp_size ,
455510 'attn_cp_size' : self .attn_cp_size ,
511+ 'cp_config' : self .cp_config ,
456512 'auto_parallel' : self .auto_parallel ,
457513 }
0 commit comments