@@ -84,12 +84,17 @@ def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata):
8484
def remove_remote_agent(self, peer_name: str):
    """Disconnect a remote peer and release the transfer handles held for it.

    The peer's entry is removed from ``self.remote_agents`` first, then the
    peer is removed from the underlying NIXL agent and, if a prepared page
    descriptor-list handle was stored for it, that handle is released.

    This is a best-effort cleanup: ordinary errors raised by the NIXL agent
    are logged and swallowed so that callers iterating over many peers are
    not interrupted. ``KeyboardInterrupt`` / ``SystemExit`` are deliberately
    NOT swallowed.

    Args:
        peer_name: Key under which the peer was stored in
            ``self.remote_agents``.
    """
    # Single lookup: pop returns None when the peer is unknown.
    remote_agent = self.remote_agents.pop(peer_name, None)
    if remote_agent is None:
        logger.warning(f"try to remove remote agent, but peer name {peer_name} agent did not exist")
        return

    try:
        if remote_agent.agent_name != peer_name:
            # Explicit check instead of `assert` (asserts vanish under -O).
            # Keep going so the NIXL-side registration is still released,
            # using the name stored on the metadata object.
            logger.error(
                f"remote agent name {remote_agent.agent_name} does not match peer name {peer_name}"
            )
        self.nixl_agent.remove_remote_agent(remote_agent.agent_name)
        if remote_agent.page_xfer_handles is not None:
            self.nixl_agent.release_dlist_handle(remote_agent.page_xfer_handles)
    except Exception:
        # logger.exception records the message together with the traceback.
        logger.exception(f"remove remote agent {peer_name} failed")
9398
9499 def send_readtask_to_decode_node (self , trans_task : NIXLChunckedTransTask ):
95100 """
@@ -101,30 +106,28 @@ def send_readtask_to_decode_node(self, trans_task: NIXLChunckedTransTask):
101106 _remote_agent = trans_task .create_decode_agent_obj ()
102107 self .connect_add_remote_agent (_remote_agent )
103108
104- if decode_agent_name in self .remote_agents :
105- # 将页面读取任务发送给 decode 节点
106- remote_agent : NixlAgentMetadata = self .remote_agents [decode_agent_name ]
107- assert trans_task .nixl_src_page_index is not None
108- new_trans_task : NIXLChunckedTransTask = copy .copy (trans_task )
109-
110- new_trans_task .decode_agent_name = None
111- new_trans_task .decode_agent_metadata = None
112- new_trans_task .decode_num_pages = None
113- new_trans_task .decode_page_reg_desc = None
114-
115- new_trans_task .prefill_agent_name = self .agent_name
116- new_trans_task .prefill_agent_metadata = self .agent_metadata
117- new_trans_task .prefill_num_pages = self .num_pages
118- new_trans_task .prefill_page_reg_desc = self .local_page_mem_desc
119-
120- # 不需要传输细节的 mem_indexes 信息
121- new_trans_task .mem_indexes = None
122- self .nixl_agent .send_notif (
123- remote_agent .agent_name ,
124- pickle .dumps (new_trans_task ),
125- )
126- else :
127- logger .error (f"decode_agent_name { decode_agent_name } not exist" )
109+
110+ # 将页面读取任务发送给 decode 节点
111+ remote_agent : NixlAgentMetadata = self .remote_agents [decode_agent_name ]
112+ assert trans_task .nixl_src_page_index is not None
113+ new_trans_task : NIXLChunckedTransTask = copy .copy (trans_task )
114+
115+ new_trans_task .decode_agent_name = None
116+ new_trans_task .decode_agent_metadata = None
117+ new_trans_task .decode_num_pages = None
118+ new_trans_task .decode_page_reg_desc = None
119+
120+ new_trans_task .prefill_agent_name = self .agent_name
121+ new_trans_task .prefill_agent_metadata = self .agent_metadata
122+ new_trans_task .prefill_num_pages = self .num_pages
123+ new_trans_task .prefill_page_reg_desc = self .local_page_mem_desc
124+
125+ # 不需要传输细节的 mem_indexes 信息
126+ new_trans_task .mem_indexes = None
127+ self .nixl_agent .send_notif (
128+ remote_agent .agent_name ,
129+ pickle .dumps (new_trans_task ),
130+ )
128131 return
129132
130133 def send_notify_to_prefill_node (self , prefill_agent_name : str , notify : bytes ):
@@ -144,29 +147,29 @@ def read_blocks_paged(
144147 _remote_agent = trans_task .create_prefill_agent_obj ()
145148 self .connect_add_remote_agent (_remote_agent )
146149
147- if prefill_agent_name in self . remote_agents :
148- assert trans_task . nixl_src_page_index is not None and trans_task . nixl_dst_page_index is not None
149- remote_agent : NixlAgentMetadata = self . remote_agents [ prefill_agent_name ]
150- src_handle = remote_agent . page_xfer_handles
151- dst_handle = self . page_local_xfer_handles
152- notify_obj = NIXLChunckedTransTaskRet (
153- request_id = trans_task .request_id ,
154- start_kv_index = trans_task .start_kv_index ,
155- end_kv_index = trans_task . end_kv_index ,
156- has_error = False ,
157- error_info = None ,
158- )
159- handle = self . nixl_agent . make_prepped_xfer (
160- "READ" ,
161- dst_handle ,
162- [ trans_task . nixl_dst_page_index ] ,
163- src_handle ,
164- [ trans_task . nixl_src_page_index ] ,
165- pickle . dumps ( notify_obj ),
166- )
167- self . nixl_agent . transfer ( handle )
168- else :
169- logger . error ( f"prefill_agent_name { prefill_agent_name } not exist" )
150+ assert trans_task . nixl_src_page_index is not None and trans_task . nixl_dst_page_index is not None
151+ remote_agent : NixlAgentMetadata = self . remote_agents [ prefill_agent_name ]
152+ src_handle = remote_agent . page_xfer_handles
153+ dst_handle = self . page_local_xfer_handles
154+ notify_obj = NIXLChunckedTransTaskRet (
155+ request_id = trans_task . request_id ,
156+ start_kv_index = trans_task .start_kv_index ,
157+ end_kv_index = trans_task .end_kv_index ,
158+ has_error = False ,
159+ error_info = None ,
160+ )
161+ handle = self . nixl_agent . make_prepped_xfer (
162+ "READ" ,
163+ dst_handle ,
164+ [ trans_task . nixl_dst_page_index ] ,
165+ src_handle ,
166+ [ trans_task . nixl_src_page_index ] ,
167+ pickle . dumps ( notify_obj ) ,
168+ )
169+ if not handle :
170+ raise RuntimeError ( f"make_prepped_xfer failed for task: { trans_task . to_str () } " )
171+
172+ self . nixl_agent . transfer ( handle )
170173
171174 return handle
172175
@@ -185,8 +188,7 @@ def release_xfer_handle(self, handle):
def shutdown(self):
    """Tear down local NIXL resources and disconnect every known peer.

    Deregisters the local page memory, releases the local page transfer
    handle, then calls ``remove_remote_agent`` once for each peer currently
    stored in ``self.remote_agents``.
    """
    nixl = self.nixl_agent
    nixl.deregister_memory(self.page_reg_desc)
    nixl.release_dlist_handle(self.page_local_xfer_handles)

    # Iterate over a snapshot of the keys, since removing a peer may
    # mutate self.remote_agents while we loop.
    for peer_name in tuple(self.remote_agents):
        self.remove_remote_agent(peer_name)
0 commit comments