@@ -337,11 +337,12 @@ def __init__(
         else:
             self.pd_disaggregation_mode = "None"
 
-    def set_tp_group(self):
+    def set_communicate_group(self):
         # different tp group id
         # prevent different tp_groups using the same group_id
         tp_gid_offset = envs.FD_TP_GROUP_GID_OFFSET
         dist.collective._set_custom_gid(self.data_parallel_rank + tp_gid_offset)
+
         self.tp_group = dist.new_group(
             range(
                 self.data_parallel_rank * self.tensor_parallel_size,
@@ -350,8 +351,11 @@ def set_tp_group(self):
         )
         dist.collective._set_custom_gid(None)
         # same ep group id
-        dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
-        self.ep_group = dist.new_group(range(self.expert_parallel_size))
+        if self.enable_expert_parallel:
+            dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+            self.ep_group = dist.new_group(range(self.expert_parallel_size))
+            dist.collective._set_custom_gid(None)
+
         logger.info(
             f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
         )
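The renamed `set_communicate_group` builds one TP group per data-parallel rank and, only when expert parallelism is enabled, a single EP group; each `new_group` call is bracketed by `_set_custom_gid` so no two groups reuse a group id. Below is a minimal, runnable sketch of that layout. The sizes and the offset value are illustrative assumptions, and the upper bound of the TP `range(...)` is inferred, since the diff only shows its lower bound:

```python
# Sketch of the rank partitioning and custom group-id scheme above.
# Assumed: the sizes below and tp_gid_offset=8 (stand-in for
# envs.FD_TP_GROUP_GID_OFFSET).
data_parallel_size = 2
tensor_parallel_size = 4
expert_parallel_size = data_parallel_size * tensor_parallel_size  # assumption
tp_gid_offset = 8

tp_gids = []
for dp_rank in range(data_parallel_size):
    tp_ranks = list(
        range(
            dp_rank * tensor_parallel_size,
            (dp_rank + 1) * tensor_parallel_size,  # inferred upper bound
        )
    )
    tp_gids.append(dp_rank + tp_gid_offset)
    print(f"tp_group gid={tp_gids[-1]} ranks={tp_ranks}")

# The EP group takes the first gid past every TP gid, so the ids never collide.
ep_gid = data_parallel_size + tp_gid_offset
assert ep_gid not in tp_gids
print(f"ep_group gid={ep_gid} ranks={list(range(expert_parallel_size))}")
```

Resetting the gid to `None` after each `new_group` (including the newly added reset after the EP group) keeps later group creation from inheriting a stale custom id.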
@@ -830,6 +834,7 @@ class LoadConfig:
         load_strategy: Specifies the weight loading method when enabled:
             - 'ipc': Real-time IPC streaming with automatic resharding
             - 'ipc_snapshot': Load from disk snapshot of IPC weights
+            - 'meta': Load only the model's metadata
             - None: No dynamic loading
     """
@@ -840,7 +845,7 @@ def __init__(
         self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
         self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
         self.dynamic_load_weight: bool = False
-        self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
+        self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal"]] = "normal"
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
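With the widened `Literal`, `load_strategy` now defaults to `"normal"` instead of `None`, and callers can still override it through the `args` dict scanned in `__init__`. A self-contained sketch of that override pattern follows; the class name is a hypothetical stand-in for `LoadConfig`:

```python
from typing import Literal, Optional

class LoadConfigSketch:  # hypothetical stand-in for LoadConfig
    def __init__(self, args: dict):
        self.dynamic_load_weight: bool = False
        self.load_strategy: Optional[
            Literal["ipc", "ipc_snapshot", "meta", "normal"]
        ] = "normal"
        # Only attributes that already exist as defaults can be overridden;
        # unknown keys are silently ignored.
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)

cfg = LoadConfigSketch({"load_strategy": "meta", "not_a_field": 1})
print(cfg.load_strategy)            # meta
print(hasattr(cfg, "not_a_field"))  # False
```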
@@ -1198,12 +1203,10 @@ def __init__(
 
         num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
-        if num_ranks > self.max_chips_per_node:
+        if num_ranks > self.max_chips_per_node and self.load_config.load_strategy != "meta":
             self.worker_num_per_node = self.max_chips_per_node
             nnode = ceil_div(num_ranks, self.worker_num_per_node)
             assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
-
-            # assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
         else:
             self.worker_num_per_node = num_ranks
 
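The `"meta"` strategy now also bypasses the multi-node split: when only model metadata is loaded, all ranks are kept on one logical node. A runnable sketch of the branch, assuming `ceil_div` is ceiling integer division and with illustrative sizes:

```python
def ceil_div(a: int, b: int) -> int:
    # assumed behavior of the ceil_div helper used above
    return -(-a // b)

max_chips_per_node = 8  # 16 on Iluvatar hardware, 8 otherwise (per the diff)
num_ranks = 4 * 4       # tensor_parallel_size * data_parallel_size
load_strategy = "meta"

if num_ranks > max_chips_per_node and load_strategy != "meta":
    worker_num_per_node = max_chips_per_node
    nnode = ceil_div(num_ranks, worker_num_per_node)  # would be 2 here
else:
    # "meta" loading (or a single-node fit) keeps every rank on one node
    worker_num_per_node = num_ranks
    nnode = 1

print(worker_num_per_node, nnode)  # 16 1
```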