@@ -30,7 +30,7 @@ def __init__(self, args, context, local_visual_port: int):
3030 self .local_visual_port = local_visual_port
3131
3232 self .send_to_visual = None
33- self .remote_vit_instances = []
33+ self .remote_vit_instances = {}
3434 self .current_vit_index = 0
3535 self .remote_vit = args .enable_remote_vit
3636 self .remote_vit_port = args .remote_vit_port
@@ -42,8 +42,7 @@ def _setup_vit_connections(self):
4242 设置VIT连接,支持本地和远程VIT实例
4343 支持多种连接模式:
4444 1. 本地VIT实例 (默认)
45- 2. 远程单个VIT实例
46- 3. 远程多个VIT实例 (负载均衡)
45+ 2. 远程多个VIT实例 (负载均衡)
4746 """
4847 if self .remote_vit :
4948 # 远程VIT实例模式
@@ -57,14 +56,86 @@ def _setup_local_vit_connection(self):
5756 logger .info (f"Connected to local VIT instance at { self .args .zmq_mode } 127.0.0.1:{ self .local_visual_port } " )
5857
def _setup_remote_vit_connections(self, max_retries: int = 30, retry_interval: float = 1.0):
    """
    Synchronously populate ``self.remote_vit_instances`` at start-up.

    ``self._sync_init_vit_instances()`` is called once right away and then
    again before every retry, for as long as no remote VIT instance has
    been registered.

    Args:
        max_retries: Number of additional attempts made after the first one.
        retry_interval: Seconds slept before each additional attempt.

    Returns:
        None. The outcome is only logged: running out of retries with no
        instance produces a warning, not an exception.

    Note:
        This blocks the calling thread. The total wait is at least
        ``max_retries * retry_interval`` seconds when nothing registers,
        plus however long each discovery attempt itself takes.
    """
    logger.info("Setting up remote VIT connections...")

    self._sync_init_vit_instances()

    for attempt in range(1, max_retries + 1):
        if self.remote_vit_instances:
            break
        logger.info(f"Waiting for VIT instances... (attempt {attempt}/{max_retries})")
        time.sleep(retry_interval)
        self._sync_init_vit_instances()

    if not self.remote_vit_instances:
        logger.warning("No VIT instances available after initialization")
    else:
        logger.info(f"Successfully connected to {len(self.remote_vit_instances)} VIT instances")
78+
def _sync_init_vit_instances(self):
    """
    Run one blocking discovery pass for remote VIT instances.

    Fetches the registry through ``self._sync_get_vit_objs()`` and, when
    the result is non-empty, passes it to
    ``self._update_vit_connections()``.

    Returns:
        None. Any exception raised by either step is logged together with
        its traceback and swallowed, so the caller can retry.
    """
    try:
        vit_objs = self._sync_get_vit_objs()
        if vit_objs:
            self._update_vit_connections(vit_objs)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(f"Failed to initialize VIT instances: {e}")
90+
def _sync_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]:
    """
    Fetch the registered VIT instances from the config server (blocking).

    Sends an HTTP GET to ``/registered_visual_objects`` on
    ``args.config_server_host:args.config_server_port`` with a 10 second
    timeout. The JSON field ``"data"`` is base64-decoded and unpickled.

    Returns:
        The unpickled object (annotated as a mapping from instance id to
        ``VIT_Obj``), or ``None`` when the status code is not 200 or any
        exception occurs. Failures are logged, never raised.
    """
    import requests

    uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects"
    try:
        response = requests.get(uri, timeout=10)
        if response.status_code != 200:
            logger.error(f"Failed to get VIT instances: {response.status_code}")
            return None

        base64data = response.json()["data"]
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload. This is only acceptable while the config server and the
        # network path to it are trusted.
        return pickle.loads(base64.b64decode(base64data))
    except Exception as e:
        logger.error(f"Error getting VIT instances: {e}")
        return None
110+
def _update_vit_connections(self, id_to_vit_obj: "Dict[int, VIT_Obj]"):
    """
    Reconcile ``self.remote_vit_instances`` with a registry snapshot.

    Sockets whose id is missing from ``id_to_vit_obj`` are closed and
    removed; ids not yet connected get a new ZMQ PUSH socket. Connections
    for ids present on both sides are left untouched.

    Args:
        id_to_vit_obj: Mapping from instance id to an object exposing
            ``host_ip_port``.

    Returns:
        None. Failures for individual instances are logged and skipped so
        that one bad instance does not block the others.
    """
    # Collect the stale ids first: the dict must not change size while it
    # is being iterated.
    stale_ids = [vit_id for vit_id in self.remote_vit_instances if vit_id not in id_to_vit_obj]
    for vit_id in stale_ids:
        stale_socket = self.remote_vit_instances.pop(vit_id)
        try:
            stale_socket.close()
        except Exception as e:
            # Best effort: the entry is dropped either way, but say why
            # the close failed instead of hiding it.
            logger.warning(f"Error while closing VIT connection {vit_id}: {e}")
        logger.info(f"Closed VIT connection {vit_id}")

    for vit_id, vit_obj in id_to_vit_obj.items():
        if vit_id in self.remote_vit_instances:
            continue

        new_socket = None
        try:
            new_socket = self.context.socket(zmq.PUSH)
            # NOTE(review): host_ip_port is used as the host part and
            # args.remote_vit_port is appended; confirm the attribute does
            # not already carry a port.
            new_socket.connect(f"tcp://{vit_obj.host_ip_port}:{self.args.remote_vit_port}")
        except Exception as e:
            logger.error(f"Failed to connect to VIT instance {vit_id}: {e}")
            if new_socket is not None:
                # Do not leak a socket that was created but never connected.
                try:
                    new_socket.close()
                except Exception as close_error:
                    logger.debug(f"Error while discarding socket for VIT instance {vit_id}: {close_error}")
            continue

        self.remote_vit_instances[vit_id] = new_socket
        logger.info(f"Connected to VIT instance {vit_id} at {vit_obj.host_ip_port}")
68139
def _get_vit_instance(self):
    """
    Pick the socket that the next request should be sent to.

    Returns:
        ``self.send_to_visual`` when ``self.remote_vit`` is falsy;
        otherwise the next remote socket in round-robin order.

    Raises:
        RuntimeError: Remote mode is enabled but no remote instance is
            currently connected.
    """
    if not self.remote_vit:
        return self.send_to_visual

    # One snapshot serves both the emptiness check and the lookup, so the
    # index is always valid for the list it is applied to.
    instances = list(self.remote_vit_instances.values())
    if not instances:
        raise RuntimeError("No available VIT instances")

    # Simple round-robin: advance the cursor first, then use it.
    self.current_vit_index = (self.current_vit_index + 1) % len(instances)
    return instances[self.current_vit_index]
80154
81155 async def send_to_vit (self , data , protocol = pickle .HIGHEST_PROTOCOL ):
82156 """
@@ -86,42 +160,32 @@ async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL):
86160 try :
87161 instance .send_pyobj (data , protocol = protocol )
88162 except Exception as e :
89- logger .error (f"Failed to send to VIT instance { instance .host_ip_port } : { e } " )
90- raise Exception (f"Failed to send to VIT instance { instance .host_ip_port } : { e } " )
163+ logger .error (f"Failed to send to VIT instance: { e } " )
164+ raise Exception (f"Failed to send to VIT instance: { e } " )
165+
166+ await self ._wait_visual_embed_ready ()
91167
async def vit_handle_loop(self, poll_interval: float = 30, error_backoff: float = 10):
    """
    Keep the remote VIT connections in sync with the registry, forever.

    Each iteration awaits ``self._async_get_vit_objs()`` and, unless the
    fetch returned ``None``, passes the result to
    ``self._update_vit_connections()``. The coroutine never returns on its
    own and has to be started by the caller.

    Args:
        poll_interval: Seconds to sleep after a normal iteration.
        error_backoff: Seconds to sleep after an iteration that raised.
    """
    logger.info("Starting VIT connection management loop")
    while True:
        try:
            id_to_vit_obj = await self._async_get_vit_objs()
            # None means the fetch failed, so keep the current connections.
            # An empty mapping is a valid answer and must still be applied,
            # otherwise connections to vanished instances are never closed.
            if id_to_vit_obj is not None:
                logger.debug(f"Retrieved {len(id_to_vit_obj)} VIT instances")
                self._update_vit_connections(id_to_vit_obj)
            await asyncio.sleep(poll_interval)
        except Exception as e:
            logger.exception(f"Error in VIT handle loop: {e}")
            await asyncio.sleep(error_backoff)
117183
118- async def _get_vit_objs (self ) -> Optional [Dict [int , VIT_Obj ]]:
184+ async def _async_get_vit_objs (self ) -> Optional [Dict [int , VIT_Obj ]]:
119185 """
120- get_vit_objs 主要负责从 config_server 获取所有的vit远程服务。
186+ 异步获取VIT实例信息
121187 """
122- # 使用 config_server 服务来发现所有的 pd_master 节点。
123188 uri = f"ws://{ self .args .config_server_host } :{ self .args .config_server_port } /registered_visual_objects"
124- print ("uri" , uri )
125189 try :
126190 async with httpx .AsyncClient () as client :
127191 response = await client .get (uri )
@@ -130,9 +194,14 @@ async def _get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]:
130194 id_to_vit_obj = pickle .loads (base64 .b64decode (base64data ))
131195 return id_to_vit_obj
132196 else :
133- logger .error (f"get pd_master_objs error { response .status_code } " )
197+ logger .error (f"Failed to get VIT instances: { response .status_code } " )
134198 return None
135199 except Exception as e :
136- logger .exception (str (e ))
137- await asyncio .sleep (10 )
200+ logger .exception (f"Error getting VIT instances: { e } " )
138201 return None
202+
async def _wait_visual_embed_ready(self, wait_seconds: float = 10):
    """
    Pause to give the VIT instance time to produce its embeddings.

    This is a fixed delay, not a readiness check: nothing is polled and no
    acknowledgement from the VIT instance is awaited.

    Args:
        wait_seconds: Seconds to sleep without blocking the event loop.
    """
    await asyncio.sleep(wait_seconds)
0 commit comments