33import logging
44import random
55import uuid
6+ from collections import namedtuple
67from json import JSONDecodeError
78from os import getenv
89from typing import Any , Callable , List , Optional
910
1011from AnyQt .QtCore import QSettings
1112from httpx import AsyncClient , NetworkError , ReadTimeout , Response
1213
13- from Orange .misc .utils .embedder_utils import (EmbedderCache ,
14- EmbeddingCancelledException ,
15- EmbeddingConnectionError ,
16- get_proxies )
14+ from Orange .misc .utils .embedder_utils import (
15+ EmbedderCache ,
16+ EmbeddingCancelledException ,
17+ EmbeddingConnectionError ,
18+ get_proxies ,
19+ )
1720
1821log = logging .getLogger (__name__ )
# One unit of work for the embedding workers: the item's position in the
# input data, the item itself, and how many times it was already retried.
TaskItem = namedtuple("TaskItem", ["id", "item", "no_repeats"])
1923
2024
2125class ServerEmbedderCommunicator :
2226 """
2327 This class needs to be inherited by the class which re-implements
24- _encode_data_instance and defines self.content_type. For sending a table
25- with data items use embedd_table function. This one is called with the
26- complete Orange data Table. Then _encode_data_instance needs to extract
27- data to be embedded from the RowInstance. For images, it takes the image
28- path from the table, load image, and transform it into bytes.
    _encode_data_instance and defines self.content_type. For sending a list
    of data items use the embedd_table function.
2930
3031 Attributes
3132 ----------
@@ -69,14 +70,14 @@ def __init__(
6970 ) or str (uuid .getnode ())
7071 except TypeError :
7172 self .machine_id = str (uuid .getnode ())
72- self .session_id = str (random .randint (1 , 1e10 ))
73+ self .session_id = str (random .randint (1 , int ( 1e10 ) ))
7374
7475 self ._cache = EmbedderCache (model_name )
7576
7677 # default embedding timeouts are too small we need to increase them
7778 self .timeout = 180
78- self .num_parallel_requests = 0
79- self . max_parallel = max_parallel_requests
79+ self .max_parallel_requests = max_parallel_requests
80+
8081 self .content_type = None # need to be set in a class inheriting
8182
8283 def embedd_data (
@@ -111,8 +112,7 @@ def embedd_data(
111112 EmbeddingCancelledException:
112113 If cancelled attribute is set to True (default=False).
113114 """
114- # if there is less items than 10 connection error should be raised
115- # earlier
        # with fewer than 10 items the connection error should be raised earlier
116116 self .max_errors = min (len (data ) * self .MAX_REPEATS , 10 )
117117
118118 loop = asyncio .new_event_loop ()
@@ -122,10 +122,10 @@ def embedd_data(
122122 self .embedd_batch (data , processed_callback )
123123 )
124124 except Exception :
125- loop .close ()
126125 raise
126+ finally :
127+ loop .close ()
127128
128- loop .close ()
129129 return embeddings
130130
131131 async def embedd_batch (
@@ -153,32 +153,63 @@ async def embedd_batch(
153153 EmbeddingCancelledException:
154154 If cancelled attribute is set to True (default=False).
155155 """
156- requests = []
156+ results = [None ] * len (data )
157+ queue = asyncio .Queue ()
158+
        # fill the queue with the items to embed
160+ for i , item in enumerate (data ):
161+ queue .put_nowait (TaskItem (id = i , item = item , no_repeats = 0 ))
162+
157163 async with AsyncClient (
158- timeout = self .timeout , base_url = self .server_url , proxies = get_proxies ()
164+ timeout = self .timeout , base_url = self .server_url , proxies = get_proxies ()
159165 ) as client :
160- for p in data :
161- if self ._cancelled :
162- raise EmbeddingCancelledException ()
163- requests .append (self ._send_to_server (p , client , proc_callback ))
166+ tasks = self ._init_workers (client , queue , results , proc_callback )
167+
168+ # wait for the queue to complete or one of workers to exit
169+ queue_complete = asyncio .create_task (queue .join ())
170+ await asyncio .wait (
171+ [queue_complete , * tasks ], return_when = asyncio .FIRST_COMPLETED
172+ )
173+
174+ # Cancel worker tasks when done
175+ queue_complete .cancel ()
176+ await self ._cancel_workers (tasks )
164177
165- embeddings = await asyncio .gather (* requests )
166178 self ._cache .persist_cache ()
167- assert self . num_parallel_requests == 0
179+ return results
168180
169- return embeddings
181+ def _init_workers (self , client , queue , results , callback ):
182+ """Init required number of workers"""
183+ t = [
184+ asyncio .create_task (self ._send_to_server (client , queue , results , callback ))
185+ for _ in range (self .max_parallel_requests )
186+ ]
187+ log .debug (f"Created { self .max_parallel_requests } workers" )
188+ return t
170189
171- async def __wait_until_released (self ) -> None :
172- while self .num_parallel_requests >= self .max_parallel :
173- await asyncio .sleep (0.1 )
190+ @staticmethod
191+ async def _cancel_workers (tasks ):
192+ """Cancel worker at the end"""
193+ log .debug (f"Canceling workers" )
194+ try :
195+ # try to catch any potential exceptions
196+ await asyncio .gather (* tasks )
197+ except Exception as ex :
198+ # raise exceptions gathered from an failed worker
199+ raise ex
200+ finally :
201+ # cancel all tasks in both cases
202+ for task in tasks :
203+ task .cancel ()
204+ # Wait until all worker tasks are cancelled.
205+ await asyncio .gather (* tasks , return_exceptions = True )
206+ log .debug (f"All workers canceled" )
174207
175208 def __check_cancelled (self ):
176209 if self ._cancelled :
177210 raise EmbeddingCancelledException ()
178211
179- async def _encode_data_instance (
180- self , data_instance : Any
181- ) -> Optional [bytes ]:
212+ async def _encode_data_instance (self , data_instance : Any ) -> Optional [bytes ]:
182213 """
183214 The reimplementation of this function must implement the procedure
184215 to encode the data item in a string format that will be sent to the
@@ -197,63 +228,73 @@ async def _encode_data_instance(
197228 raise NotImplementedError
198229
199230 async def _send_to_server (
200- self ,
201- data_instance : Any ,
202- client : AsyncClient ,
203- proc_callback : Callable [[bool ], None ] = None ,
204- ) -> Optional [List [float ]]:
231+ self ,
232+ client : AsyncClient ,
233+ queue : asyncio .Queue ,
234+ results : List ,
235+ proc_callback : Callable [[bool ], None ] = None ,
236+ ):
205237 """
206- Function get an data instance . It extract data from it and send them to
207- server and retrieve responses.
238+ Worker that embedds data. It is pulling items from the until the queue
239+ is empty. It is canceled by embedd_batch all tasks are finished
208240
209241 Parameters
210242 ----------
211- data_instance
212- Single row of the input table.
213243 client
214244 HTTPX client that communicates with the server
245+ queue
246+ The queue with items of type TaskItem to be embedded
247+ results
248+ The list to append results in. The list has length equal to numbers
249+ of all items to embedd. The result need to be inserted at the index
250+ defined in queue items.
215251 proc_callback
216252 A function that is called after each item is fully processed
217253 by either getting a successful response from the server,
218254 getting the result from cache or skipping the item.
219-
220- Returns
221- -------
222- Embedding. For items that are not successfully embedded returns None.
223255 """
224- await self .__wait_until_released ()
225- self .__check_cancelled ()
226-
227- self .num_parallel_requests += 1
228- # load bytes
229- data_bytes = await self ._encode_data_instance (data_instance )
230- if data_bytes is None :
231- self .num_parallel_requests -= 1
232- return None
256+ while not queue .empty ():
257+ self .__check_cancelled ()
258+
259+ # get item from the queue
260+ i , data_instance , num_repeats = await queue .get ()
261+
262+ # load bytes
263+ data_bytes = await self ._encode_data_instance (data_instance )
264+ if data_bytes is None :
265+ continue
233266
234- # if data in cache return it
235- cache_key = self ._cache .md5_hash (data_bytes )
236- emb = self ._cache .get_cached_result_or_none (cache_key )
267+ # retrieve embedded item from the local cache
268+ cache_key = self ._cache .md5_hash (data_bytes )
269+ log .debug (f"Embedding { cache_key } " )
270+ emb = self ._cache .get_cached_result_or_none (cache_key )
237271
238- if emb is None :
239- # in case that embedding not sucessfull resend it to the server
240- # maximally for MAX_REPEATS time
241- for i in range (1 , self .MAX_REPEATS + 1 ):
242- self .__check_cancelled ()
272+ if emb is None :
273+ # send the item to the server for embedding if not in the local cache
274+ log .debug (f"Sending to the server: { cache_key } " )
243275 url = (
244- f"/{ self .embedder_type } /{ self ._model } ?"
245- f"machine={ self .machine_id } "
246- f"&session={ self .session_id } &retry={ i } "
276+ f"/{ self .embedder_type } /{ self ._model } ?machine={ self .machine_id } "
277+ f"&session={ self .session_id } &retry={ num_repeats + 1 } "
247278 )
248279 emb = await self ._send_request (client , data_bytes , url )
249280 if emb is not None :
250281 self ._cache .add (cache_key , emb )
251- break # repeat only when embedding None
252- if proc_callback :
253- proc_callback (emb is not None )
254282
255- self .num_parallel_requests -= 1
256- return emb
283+ if emb is not None :
284+ # store result if embedding is successful
285+ log .debug (f"Successfully embedded: { cache_key } " )
286+ results [i ] = emb
287+ if proc_callback :
288+ proc_callback (emb is not None )
289+ elif num_repeats < self .MAX_REPEATS :
290+ log .debug (f"Not embedded successfully - reading to queue: { cache_key } " )
291+ # if embedding not successful put the item to queue to be handled at
292+ # the end - the item is put to the end since it is possible that server
293+ # still process the request and the result will be in the cache later
294+ # repeating the request immediately may result in another fail when
295+ # processing takes longer
296+ queue .put_nowait (TaskItem (i , data_instance , no_repeats = num_repeats + 1 ))
297+ queue .task_done ()
257298
258299 async def _send_request (
259300 self , client : AsyncClient , data : bytes , url : str
@@ -284,27 +325,23 @@ async def _send_request(
284325 response = await client .post (url , headers = headers , data = data )
285326 except ReadTimeout as ex :
286327 log .debug ("Read timeout" , exc_info = True )
287- # it happens when server do not respond in 60 seconds, in
288- # this case we return None and items will be resend later
328+ # it happens when server do not respond in time defined by timeout
329+ # return None and items will be resend later
289330
290331 # if it happens more than in ten consecutive cases it means
291332 # sth is wrong with embedder we stop embedding
292333 self .count_read_errors += 1
293-
294334 if self .count_read_errors >= self .max_errors :
295- self .num_parallel_requests = 0 # for safety reasons
296335 raise EmbeddingConnectionError from ex
297336 return None
298337 except (OSError , NetworkError ) as ex :
299338 log .debug ("Network error" , exc_info = True )
300- # it happens when no connection and items cannot be sent to the
301- # server
302- # we count number of consecutive errors
339+ # it happens when no connection and items cannot be sent to server
340+
303341 # if more than 10 consecutive errors it means there is no
304342 # connection so we stop embedding with EmbeddingConnectionError
305343 self .count_connection_errors += 1
306344 if self .count_connection_errors >= self .max_errors :
307- self .num_parallel_requests = 0 # for safety reasons
308345 raise EmbeddingConnectionError from ex
309346 return None
310347 except Exception :
0 commit comments