@@ -43,6 +43,7 @@ def __init__(self, cfg, engine, dp_rank):
43
43
target = self ._response_external_module_control_instruct , daemon = True
44
44
)
45
45
self .response_external_instruct_thread .start ()
46
+ self .response_lock = threading .Lock () # prevent to call send_multipart in zmq concurrently
46
47
47
48
def _get_current_server_info (self ):
48
49
"""
@@ -76,7 +77,8 @@ def _recv_external_module_control_instruct(self):
76
77
payload_info = self ._get_current_server_info ()
77
78
result = {"task_id" : task_id_str , "result" : payload_info }
78
79
logger .info (f"Response for task: { task_id_str } " )
79
- self .recv_control_cmd_server .response_for_control_cmd (task_id_str , result )
80
+ with self .response_lock :
81
+ self .recv_control_cmd_server .response_for_control_cmd (task_id_str , result )
80
82
81
83
elif task ["cmd" ] == "get_metrics" :
82
84
metrics_text = get_filtered_metrics (
@@ -85,7 +87,8 @@ def _recv_external_module_control_instruct(self):
85
87
)
86
88
result = {"task_id" : task_id_str , "result" : metrics_text }
87
89
logger .info (f"Response for task: { task_id_str } " )
88
- self .recv_control_cmd_server .response_for_control_cmd (task_id_str , result )
90
+ with self .response_lock :
91
+ self .recv_control_cmd_server .response_for_control_cmd (task_id_str , result )
89
92
elif task ["cmd" ] == "connect_rdma" :
90
93
self .engine .engine_worker_queue .put_connect_rdma_task (task )
91
94
@@ -100,7 +103,8 @@ def _response_external_module_control_instruct(self):
100
103
task_id_str = result_data ["task_id" ]
101
104
result = {"task_id" : task_id_str , "result" : result_data }
102
105
logger .info (f"Response for task: { task_id_str } " )
103
- self .recv_control_cmd_server .response_for_control_cmd (task_id_str , result )
106
+ with self .response_lock :
107
+ self .recv_control_cmd_server .response_for_control_cmd (task_id_str , result )
104
108
else :
105
109
time .sleep (0.001 )
106
110
except Exception as e :
0 commit comments