@@ -83,12 +83,18 @@ class PDTransJoinInfo:
8383 prefill_device_id : int
8484 pd_prefill_nccl_ip : str
8585 pd_prefill_nccl_port : int
86+ # 用于标识一次唯一的连接,prefill_id 和 decode_id 相同时,可能因为网络原因重连,为了更好的区分
87+ # 一次连接,使用一个 uuid 为其标识
88+ connect_id : str
8689
8790
8891@dataclass
8992class PDTransLeaveInfo :
9093 decode_id : int
9194 prefill_id : int
95+ # 用于标识一次唯一的连接,prefill_id 和 decode_id 相同时,可能因为网络原因重连,为了更好的区分
96+ # 一次连接,使用一个 uuid 为其标识
97+ connect_id : str
9298
9399
94100@dataclass
@@ -106,6 +112,8 @@ class KVMoveTask:
106112 prefill_dp_index : int
107113 decode_dp_index : int
108114 mark_start_time : float = None
115+ # 标记任务使用某个连接id进行传输
116+ connect_id : str = None
109117
110118 def __post_init__ (self ):
111119 if len (self .input_tokens ) <= 0 :
@@ -118,14 +126,14 @@ def to_prefill_log_info(self):
118126 d_i = self .prefill_dp_index
119127 id = self .group_request_id
120128 log = f"id: { id } in_len:{ len (self .input_tokens )} v_len: { v_len } move_len: { self .move_kv_len } dp_index:{ d_i } "
121- return log
129+ return log + f" connect_id: { self . connect_id } "
122130
123131 def to_decode_log_info (self ):
124132 v_len = None if self .decode_token_indexes is None else len (self .decode_token_indexes )
125133 d_i = self .decode_dp_index
126134 id = self .group_request_id
127135 log = f"id: { id } in_len:{ len (self .input_tokens )} v_len: { v_len } move_len: { self .move_kv_len } dp_index:{ d_i } "
128- return log
136+ return log + f" connect_id: { self . connect_id } "
129137
130138 def id (self ):
131139 return self .group_request_id
@@ -135,3 +143,8 @@ def get_cost_time(self):
135143 return time .time () - self .mark_start_time
136144 else :
137145 return 100000000000
146+
147+ @dataclass
148+ class KVMoveTaskGroup :
149+ tasks : List [KVMoveTask ]
150+ connect_id : str
0 commit comments