1717
1818
1919class InferStateLock :
def __init__(self, name, rank_in_dp: int, dp_rank_in_node: int, dp_world_size: int):
    """Create the local lock and the shared per-rank mark array.

    Args:
        name: Prefix used to build the shared array's name.
        rank_in_dp: Index of this process's slot in the mark array.
        dp_rank_in_node: Embedded in the shared array's name, so each
            dp rank in a node gets a separately named array.
        dp_world_size: Number of per-rank mark slots; the array holds
            one extra trailing slot on top of these.
    """
    self.infer_lock = threading.Lock()
    self.dp_rank_in_node = dp_rank_in_node
    # NOTE (original author): the sync world size should be
    # min(dp_world_size, node_world_size).
    self.dp_world_size = dp_world_size
    self.rank_in_dp = rank_in_dp
    # Slots [0, dp_world_size) hold one mark per rank; the final slot
    # (index -1) is the group wait mark written by set_group_wait_mark /
    # unset_group_wait_mark.
    self.lock_tp_infos = SharedArray(
        f"{name}_dp_rank_{self.dp_rank_in_node}_lock_tp_infos",
        shape=(self.dp_world_size + 1,),
        dtype=np.int64,
    )
    # Every constructor call zeroes the whole array.
    # NOTE(review): presumably all ranks construct this before any rank
    # starts marking — confirm against the caller.
    self.lock_tp_infos.arr[:] = 0
2729
def add_cur_mark(self):
    """Increment this rank's mark in the ``lock_tp_infos`` array.

    The slot is selected by ``self.rank_in_dp``.
    """
    self.lock_tp_infos.arr[self.rank_in_dp] += 1
3032
def get_cur_mark(self):
    """Return this rank's current mark (the slot at ``self.rank_in_dp``)."""
    return self.lock_tp_infos.arr[self.rank_in_dp]
3335
def get_max_mark_in_group(self):
    """Return the largest mark among the first ``dp_world_size`` slots.

    The slice stops at ``dp_world_size``, so any slot past it (such as
    the trailing one) does not take part in the maximum.
    """
    return np.max(self.lock_tp_infos.arr[0 : self.dp_world_size])
3638
def judge_cur_mark_equal_max_mark_in_group(self):
    """Return True when this rank's mark equals the group's maximum mark.

    Compares ``get_cur_mark()`` with ``get_max_mark_in_group()``. The
    result is coerced to a plain ``bool``, matching
    ``judge_mark_in_group_all_same``.
    """
    return bool(self.get_cur_mark() == self.get_max_mark_in_group())
3941
def judge_mark_in_group_all_same(self):
    """Return True when the first ``dp_world_size`` marks are all equal.

    Only slots ``[0, dp_world_size)`` are compared; later slots are
    outside the slice and ignored.
    """
    marks = self.lock_tp_infos.arr[0 : self.dp_world_size]
    return bool(np.all(marks == marks[0]))
4345
4446 def acquire_lock_and_update_cur_mark (self ):
@@ -49,11 +51,11 @@ def release_lock(self):
4951 self .infer_lock .release ()
5052
def set_group_wait_mark(self):
    """Set the group wait mark (last array slot) to 1.

    Only the process whose ``rank_in_dp`` is 0 writes; for every other
    rank this is a no-op.
    """
    if self.rank_in_dp == 0:
        self.lock_tp_infos.arr[-1] = 1
5456
def unset_group_wait_mark(self):
    """Clear the group wait mark (last array slot) back to 0.

    Only the process whose ``rank_in_dp`` is 0 writes; for every other
    rank this is a no-op.
    """
    if self.rank_in_dp == 0:
        self.lock_tp_infos.arr[-1] = 0
5860
5961 def get_group_wait_mark (self ):
@@ -63,7 +65,7 @@ def get_group_wait_mark(self):
6365@dataclass
6466class G_Infer_Lock :
6567 obj : InferStateLock = None
66- dp_size : int = None
68+ dp_world_size : int = None
6769
6870 def acquire (self ):
6971 if self .obj is not None :
@@ -86,9 +88,8 @@ def release(self):
8688
8789# 下面两个函数需要配对使用
8890def acquire_lock_until_ready (nccl_group ):
89- # 在 deepseekv2 的tp dp 混合运行模式下, 不需要多个推理进程间做协调同步
90- # 所以直接加锁,解锁即可
91- if g_infer_state_lock .dp_size != 1 :
91+ # 单卡一tp不用过度加锁
92+ if g_infer_state_lock .dp_world_size == 1 :
9293 g_infer_state_lock .obj .infer_lock .acquire ()
9394 return
9495
@@ -118,7 +119,7 @@ def release_acquired_lock():
118119@dataclass
119120class G_Router_Lock :
120121 """
121- 保护pd分离模式下, 一些数据的操作 。
122+ 保护pd分离模式下, 一些调度相关信息数据的操作 。
122123 """
123124
124125 obj = None # 进程锁对象
0 commit comments