Skip to content

Commit 8c3c44c

Browse files
author
wangzaijun
committed
fix
1 parent: 7884190 · commit: 8c3c44c

File tree

3 files changed

+95
-65
lines changed

3 files changed

+95
-65
lines changed

lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -153,7 +153,12 @@ def accept_peer_task_loop(
153153
continue
154154

155155
# notify update
156-
notifies_dict = self.transporter.get_new_notifs()
156+
try:
157+
notifies_dict = self.transporter.get_new_notifs()
158+
except BaseException as e:
159+
logger.error(f"get new notifies failed: {str(e)}")
160+
logger.exception(str(e))
161+
notifies_dict = {}
157162

158163
if notifies_dict:
159164
for remote_agent_name, _notify_list in notifies_dict.items():
@@ -220,12 +225,19 @@ def read_peer_kv_loop(self):
220225
self.failed_queue.put(local_trans_task)
221226
continue
222227

223-
224-
xfer_handle = self.transporter.read_blocks_paged(trans_task=local_trans_task)
225-
local_trans_task.xfer_handle = xfer_handle
226-
local_trans_task.start_trans_time = time.time()
227-
with self.update_status_task_list_lock:
228-
self.update_status_task_list.append(local_trans_task)
228+
try:
229+
xfer_handle = self.transporter.read_blocks_paged(trans_task=local_trans_task)
230+
local_trans_task.xfer_handle = xfer_handle
231+
local_trans_task.start_trans_time = time.time()
232+
with self.update_status_task_list_lock:
233+
self.update_status_task_list.append(local_trans_task)
234+
except BaseException as e:
235+
logger.error(f"read_blocks_paged node failed: {local_trans_task.to_str()}")
236+
logger.exception(str(e))
237+
self.transporter.remove_remote_agent(peer_name=local_trans_task.decode_agent_name)
238+
local_trans_task.error_info = f"read_blocks_paged failed: {str(e)}"
239+
self.failed_queue.put(local_trans_task)
240+
continue
229241

230242

231243
@log_exception

lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py

Lines changed: 58 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -84,12 +84,17 @@ def connect_add_remote_agent(self, remote_agent: NixlAgentMetadata):
8484

8585
def remove_remote_agent(self, peer_name: str):
8686
if peer_name in self.remote_agents:
87-
self.nixl_agent.remove_remote_agent(peer_name)
88-
remote_agent: NixlAgentMetadata = self.remote_agents.pop(peer_name, None)
89-
if remote_agent.page_xfer_handles is not None:
90-
self.nixl_agent.release_dlist_handle(remote_agent.page_xfer_handles)
87+
try:
88+
remote_agent: NixlAgentMetadata = self.remote_agents.pop(peer_name, None)
89+
assert remote_agent.agent_name == peer_name
90+
self.nixl_agent.remove_remote_agent(remote_agent.agent_name)
91+
if remote_agent.page_xfer_handles is not None:
92+
self.nixl_agent.release_dlist_handle(remote_agent.page_xfer_handles)
93+
except BaseException as e:
94+
logger.error(f"remove remote agent {peer_name} failed")
95+
logger.exception(str(e))
9196
else:
92-
logger.warning(f"peer name {peer_name} agent didnot exist")
97+
logger.warning(f"try to remove remote agent, but peer name {peer_name} agent did not exist")
9398

9499
def send_readtask_to_decode_node(self, trans_task: NIXLChunckedTransTask):
95100
"""
@@ -101,30 +106,28 @@ def send_readtask_to_decode_node(self, trans_task: NIXLChunckedTransTask):
101106
_remote_agent = trans_task.create_decode_agent_obj()
102107
self.connect_add_remote_agent(_remote_agent)
103108

104-
if decode_agent_name in self.remote_agents:
105-
# 将页面读取任务发送给 decode 节点
106-
remote_agent: NixlAgentMetadata = self.remote_agents[decode_agent_name]
107-
assert trans_task.nixl_src_page_index is not None
108-
new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task)
109-
110-
new_trans_task.decode_agent_name = None
111-
new_trans_task.decode_agent_metadata = None
112-
new_trans_task.decode_num_pages = None
113-
new_trans_task.decode_page_reg_desc = None
114-
115-
new_trans_task.prefill_agent_name = self.agent_name
116-
new_trans_task.prefill_agent_metadata = self.agent_metadata
117-
new_trans_task.prefill_num_pages = self.num_pages
118-
new_trans_task.prefill_page_reg_desc = self.local_page_mem_desc
119-
120-
# 不需要传输细节的 mem_indexes 信息
121-
new_trans_task.mem_indexes = None
122-
self.nixl_agent.send_notif(
123-
remote_agent.agent_name,
124-
pickle.dumps(new_trans_task),
125-
)
126-
else:
127-
logger.error(f"decode_agent_name {decode_agent_name} not exist")
109+
110+
# 将页面读取任务发送给 decode 节点
111+
remote_agent: NixlAgentMetadata = self.remote_agents[decode_agent_name]
112+
assert trans_task.nixl_src_page_index is not None
113+
new_trans_task: NIXLChunckedTransTask = copy.copy(trans_task)
114+
115+
new_trans_task.decode_agent_name = None
116+
new_trans_task.decode_agent_metadata = None
117+
new_trans_task.decode_num_pages = None
118+
new_trans_task.decode_page_reg_desc = None
119+
120+
new_trans_task.prefill_agent_name = self.agent_name
121+
new_trans_task.prefill_agent_metadata = self.agent_metadata
122+
new_trans_task.prefill_num_pages = self.num_pages
123+
new_trans_task.prefill_page_reg_desc = self.local_page_mem_desc
124+
125+
# 不需要传输细节的 mem_indexes 信息
126+
new_trans_task.mem_indexes = None
127+
self.nixl_agent.send_notif(
128+
remote_agent.agent_name,
129+
pickle.dumps(new_trans_task),
130+
)
128131
return
129132

130133
def send_notify_to_prefill_node(self, prefill_agent_name: str, notify: bytes):
@@ -144,29 +147,29 @@ def read_blocks_paged(
144147
_remote_agent = trans_task.create_prefill_agent_obj()
145148
self.connect_add_remote_agent(_remote_agent)
146149

147-
if prefill_agent_name in self.remote_agents:
148-
assert trans_task.nixl_src_page_index is not None and trans_task.nixl_dst_page_index is not None
149-
remote_agent: NixlAgentMetadata = self.remote_agents[prefill_agent_name]
150-
src_handle = remote_agent.page_xfer_handles
151-
dst_handle = self.page_local_xfer_handles
152-
notify_obj = NIXLChunckedTransTaskRet(
153-
request_id=trans_task.request_id,
154-
start_kv_index=trans_task.start_kv_index,
155-
end_kv_index=trans_task.end_kv_index,
156-
has_error=False,
157-
error_info=None,
158-
)
159-
handle = self.nixl_agent.make_prepped_xfer(
160-
"READ",
161-
dst_handle,
162-
[trans_task.nixl_dst_page_index],
163-
src_handle,
164-
[trans_task.nixl_src_page_index],
165-
pickle.dumps(notify_obj),
166-
)
167-
self.nixl_agent.transfer(handle)
168-
else:
169-
logger.error(f"prefill_agent_name {prefill_agent_name} not exist")
150+
assert trans_task.nixl_src_page_index is not None and trans_task.nixl_dst_page_index is not None
151+
remote_agent: NixlAgentMetadata = self.remote_agents[prefill_agent_name]
152+
src_handle = remote_agent.page_xfer_handles
153+
dst_handle = self.page_local_xfer_handles
154+
notify_obj = NIXLChunckedTransTaskRet(
155+
request_id=trans_task.request_id,
156+
start_kv_index=trans_task.start_kv_index,
157+
end_kv_index=trans_task.end_kv_index,
158+
has_error=False,
159+
error_info=None,
160+
)
161+
handle = self.nixl_agent.make_prepped_xfer(
162+
"READ",
163+
dst_handle,
164+
[trans_task.nixl_dst_page_index],
165+
src_handle,
166+
[trans_task.nixl_src_page_index],
167+
pickle.dumps(notify_obj),
168+
)
169+
if not handle:
170+
raise RuntimeError(f"make_prepped_xfer failed for task: {trans_task.to_str()}")
171+
172+
self.nixl_agent.transfer(handle)
170173

171174
return handle
172175

@@ -185,8 +188,7 @@ def release_xfer_handle(self, handle):
185188
def shutdown(self):
186189
self.nixl_agent.deregister_memory(self.page_reg_desc)
187190
self.nixl_agent.release_dlist_handle(self.page_local_xfer_handles)
188-
for agent_name, remote_agent in self.remote_agents.items():
189-
self.nixl_agent.remove_remote_agent(remote_agent.agent_name)
190-
if remote_agent.page_xfer_handles is not None:
191-
self.nixl_agent.release_dlist_handle(remote_agent.page_xfer_handles)
191+
agent_names = list(self.remote_agents.keys())
192+
for agent_name in agent_names:
193+
self.remove_remote_agent(agent_name)
192194
return

lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,17 @@ def notify_peer_to_read_kv_loop(self):
151151
sync_event: torch.cuda.Event = sync_event
152152

153153
sync_event.synchronize()
154-
self.transporter.send_readtask_to_decode_node(trans_task=trans_task)
154+
155+
try:
156+
self.transporter.send_readtask_to_decode_node(trans_task=trans_task)
157+
except BaseException as e:
158+
logger.error(f"send readtask to decode node failed: {trans_task.to_str()}")
159+
logger.exception(str(e))
160+
self.transporter.remove_remote_agent(peer_name=trans_task.decode_agent_name)
161+
trans_task.error_info = f"send readtask to decode node failed: {str(e)}"
162+
self.failed_queue.put(trans_task)
163+
continue
164+
155165
logger.info(f"send readtask to decode: {trans_task.to_str()}")
156166

157167
with self.waiting_dict_lock:
@@ -168,7 +178,13 @@ def update_task_status_loop(
168178
continue
169179

170180
# notify update
171-
notifies_dict = self.transporter.get_new_notifs()
181+
try:
182+
notifies_dict = self.transporter.get_new_notifs()
183+
except BaseException as e:
184+
logger.error(f"get new notifies failed: {str(e)}")
185+
logger.exception(str(e))
186+
notifies_dict = {}
187+
172188
if notifies_dict:
173189
for _, _notify_list in notifies_dict.items():
174190
for notify in _notify_list:

0 commit comments

Comments
 (0)