33import logging
44import random
55import uuid
6+ from collections import namedtuple
67from json import JSONDecodeError
78from os import getenv
89from typing import Any , Callable , List , Optional
910
1011from AnyQt .QtCore import QSettings
1112from httpx import AsyncClient , NetworkError , ReadTimeout , Response
1213
13- from Orange .misc .utils .embedder_utils import (EmbedderCache ,
14- EmbeddingCancelledException ,
15- EmbeddingConnectionError ,
16- get_proxies )
14+ from Orange .misc .utils .embedder_utils import (
15+ EmbedderCache ,
16+ EmbeddingCancelledException ,
17+ EmbeddingConnectionError ,
18+ get_proxies ,
19+ )
1720
1821log = logging .getLogger (__name__ )
22+ TaskItem = namedtuple ("TaskItem" , ("id" , "item" , "no_repeats" ))
1923
2024
2125class ServerEmbedderCommunicator :
2226 """
2327 This class needs to be inherited by the class which re-implements
24- _encode_data_instance and defines self.content_type. For sending a table
25- with data items use embedd_table function. This one is called with the
26- complete Orange data Table. Then _encode_data_instance needs to extract
27- data to be embedded from the RowInstance. For images, it takes the image
28- path from the table, load image, and transform it into bytes.
28+ _encode_data_instance and defines self.content_type. For sending a list
29+ with data items use embedd_table function.
2930
3031 Attributes
3132 ----------
@@ -69,14 +70,14 @@ def __init__(
6970 ) or str (uuid .getnode ())
7071 except TypeError :
7172 self .machine_id = str (uuid .getnode ())
72- self .session_id = str (random .randint (1 , 1e10 ))
73+ self .session_id = str (random .randint (1 , int ( 1e10 ) ))
7374
7475 self ._cache = EmbedderCache (model_name )
7576
7677 # default embedding timeouts are too small we need to increase them
7778 self .timeout = 180
78- self .num_parallel_requests = 0
79- self . max_parallel = max_parallel_requests
79+ self .max_parallel_requests = max_parallel_requests
80+
8081 self .content_type = None # need to be set in a class inheriting
8182
8283 def embedd_data (
@@ -111,8 +112,7 @@ def embedd_data(
111112 EmbeddingCancelledException:
112113 If cancelled attribute is set to True (default=False).
113114 """
114- # if there is less items than 10 connection error should be raised
115- # earlier
115+ # if there is less items than 10 connection error should be raised earlier
116116 self .max_errors = min (len (data ) * self .MAX_REPEATS , 10 )
117117
118118 loop = asyncio .new_event_loop ()
@@ -121,11 +121,9 @@ def embedd_data(
121121 embeddings = asyncio .get_event_loop ().run_until_complete (
122122 self .embedd_batch (data , processed_callback )
123123 )
124- except Exception :
124+ finally :
125125 loop .close ()
126- raise
127126
128- loop .close ()
129127 return embeddings
130128
131129 async def embedd_batch (
@@ -153,32 +151,63 @@ async def embedd_batch(
153151 EmbeddingCancelledException:
154152 If cancelled attribute is set to True (default=False).
155153 """
156- requests = []
154+ results = [None ] * len (data )
155+ queue = asyncio .Queue ()
156+
157+ # fill the queue with items to embedd
158+ for i , item in enumerate (data ):
159+ queue .put_nowait (TaskItem (id = i , item = item , no_repeats = 0 ))
160+
157161 async with AsyncClient (
158- timeout = self .timeout , base_url = self .server_url , proxies = get_proxies ()
162+ timeout = self .timeout , base_url = self .server_url , proxies = get_proxies ()
159163 ) as client :
160- for p in data :
161- if self ._cancelled :
162- raise EmbeddingCancelledException ()
163- requests .append (self ._send_to_server (p , client , proc_callback ))
164+ tasks = self ._init_workers (client , queue , results , proc_callback )
165+
166+ # wait for the queue to complete or one of workers to exit
167+ queue_complete = asyncio .create_task (queue .join ())
168+ await asyncio .wait (
169+ [queue_complete , * tasks ], return_when = asyncio .FIRST_COMPLETED
170+ )
171+
172+ # Cancel worker tasks when done
173+ queue_complete .cancel ()
174+ await self ._cancel_workers (tasks )
164175
165- embeddings = await asyncio .gather (* requests )
166176 self ._cache .persist_cache ()
167- assert self . num_parallel_requests == 0
177+ return results
168178
169- return embeddings
179+ def _init_workers (self , client , queue , results , callback ):
180+ """Init required number of workers"""
181+ t = [
182+ asyncio .create_task (self ._send_to_server (client , queue , results , callback ))
183+ for _ in range (self .max_parallel_requests )
184+ ]
185+ log .debug ("Created %d workers" , self .max_parallel_requests )
186+ return t
170187
171- async def __wait_until_released (self ) -> None :
172- while self .num_parallel_requests >= self .max_parallel :
173- await asyncio .sleep (0.1 )
188+ @staticmethod
189+ async def _cancel_workers (tasks ):
190+ """Cancel worker at the end"""
191+ log .debug ("Canceling workers" )
192+ try :
193+ # try to catch any potential exceptions
194+ await asyncio .gather (* tasks )
195+ except Exception as ex :
196+ # raise exceptions gathered from an failed worker
197+ raise ex
198+ finally :
199+ # cancel all tasks in both cases
200+ for task in tasks :
201+ task .cancel ()
202+ # Wait until all worker tasks are cancelled.
203+ await asyncio .gather (* tasks , return_exceptions = True )
204+ log .debug ("All workers canceled" )
174205
175206 def __check_cancelled (self ):
176207 if self ._cancelled :
177208 raise EmbeddingCancelledException ()
178209
179- async def _encode_data_instance (
180- self , data_instance : Any
181- ) -> Optional [bytes ]:
210+ async def _encode_data_instance (self , data_instance : Any ) -> Optional [bytes ]:
182211 """
183212 The reimplementation of this function must implement the procedure
184213 to encode the data item in a string format that will be sent to the
@@ -197,63 +226,74 @@ async def _encode_data_instance(
197226 raise NotImplementedError
198227
199228 async def _send_to_server (
200- self ,
201- data_instance : Any ,
202- client : AsyncClient ,
203- proc_callback : Callable [[bool ], None ] = None ,
204- ) -> Optional [List [float ]]:
229+ self ,
230+ client : AsyncClient ,
231+ queue : asyncio .Queue ,
232+ results : List ,
233+ proc_callback : Callable [[bool ], None ] = None ,
234+ ):
205235 """
206- Function get an data instance . It extract data from it and send them to
207- server and retrieve responses.
236+ Worker that embedds data. It is pulling items from the until the queue
237+ is empty. It is canceled by embedd_batch all tasks are finished
208238
209239 Parameters
210240 ----------
211- data_instance
212- Single row of the input table.
213241 client
214242 HTTPX client that communicates with the server
243+ queue
244+ The queue with items of type TaskItem to be embedded
245+ results
246+ The list to append results in. The list has length equal to numbers
247+ of all items to embedd. The result need to be inserted at the index
248+ defined in queue items.
215249 proc_callback
216250 A function that is called after each item is fully processed
217251 by either getting a successful response from the server,
218252 getting the result from cache or skipping the item.
219-
220- Returns
221- -------
222- Embedding. For items that are not successfully embedded returns None.
223253 """
224- await self .__wait_until_released ()
225- self .__check_cancelled ()
226-
227- self .num_parallel_requests += 1
228- # load bytes
229- data_bytes = await self ._encode_data_instance (data_instance )
230- if data_bytes is None :
231- self .num_parallel_requests -= 1
232- return None
233-
234- # if data in cache return it
235- cache_key = self ._cache .md5_hash (data_bytes )
236- emb = self ._cache .get_cached_result_or_none (cache_key )
237-
238- if emb is None :
239- # in case that embedding not sucessfull resend it to the server
240- # maximally for MAX_REPEATS time
241- for i in range (1 , self .MAX_REPEATS + 1 ):
242- self .__check_cancelled ()
254+ while not queue .empty ():
255+ self .__check_cancelled ()
256+
257+ # get item from the queue
258+ i , data_instance , num_repeats = await queue .get ()
259+ num_repeats += 1
260+
261+ # load bytes
262+ data_bytes = await self ._encode_data_instance (data_instance )
263+ if data_bytes is None :
264+ continue
265+
266+ # retrieve embedded item from the local cache
267+ cache_key = self ._cache .md5_hash (data_bytes )
268+ log .debug ("Embedding %s" , cache_key )
269+ emb = self ._cache .get_cached_result_or_none (cache_key )
270+
271+ if emb is None :
272+ # send the item to the server for embedding if not in the local cache
273+ log .debug ("Sending to the server: %s" , cache_key )
243274 url = (
244- f"/{ self .embedder_type } /{ self ._model } ?"
245- f"machine={ self .machine_id } "
246- f"&session={ self .session_id } &retry={ i } "
275+ f"/{ self .embedder_type } /{ self ._model } ?machine={ self .machine_id } "
276+ f"&session={ self .session_id } &retry={ num_repeats } "
247277 )
248278 emb = await self ._send_request (client , data_bytes , url )
249279 if emb is not None :
250280 self ._cache .add (cache_key , emb )
251- break # repeat only when embedding None
252- if proc_callback :
253- proc_callback (emb is not None )
254281
255- self .num_parallel_requests -= 1
256- return emb
282+ if emb is not None :
283+ # store result if embedding is successful
284+ log .debug ("Successfully embedded: %s" , cache_key )
285+ results [i ] = emb
286+ if proc_callback :
287+ proc_callback (emb is not None )
288+ elif num_repeats < self .MAX_REPEATS :
289+ log .debug ("Embedding unsuccessful - reading to queue: %s" , cache_key )
290+ # if embedding not successful put the item to queue to be handled at
291+ # the end - the item is put to the end since it is possible that server
292+ # still process the request and the result will be in the cache later
293+ # repeating the request immediately may result in another fail when
294+ # processing takes longer
295+ queue .put_nowait (TaskItem (i , data_instance , no_repeats = num_repeats ))
296+ queue .task_done ()
257297
258298 async def _send_request (
259299 self , client : AsyncClient , data : bytes , url : str
@@ -284,27 +324,23 @@ async def _send_request(
284324 response = await client .post (url , headers = headers , data = data )
285325 except ReadTimeout as ex :
286326 log .debug ("Read timeout" , exc_info = True )
287- # it happens when server do not respond in 60 seconds, in
288- # this case we return None and items will be resend later
327+ # it happens when server do not respond in time defined by timeout
328+ # return None and items will be resend later
289329
290330 # if it happens more than in ten consecutive cases it means
291331 # sth is wrong with embedder we stop embedding
292332 self .count_read_errors += 1
293-
294333 if self .count_read_errors >= self .max_errors :
295- self .num_parallel_requests = 0 # for safety reasons
296334 raise EmbeddingConnectionError from ex
297335 return None
298336 except (OSError , NetworkError ) as ex :
299337 log .debug ("Network error" , exc_info = True )
300- # it happens when no connection and items cannot be sent to the
301- # server
302- # we count number of consecutive errors
338+ # it happens when no connection and items cannot be sent to server
339+
303340 # if more than 10 consecutive errors it means there is no
304341 # connection so we stop embedding with EmbeddingConnectionError
305342 self .count_connection_errors += 1
306343 if self .count_connection_errors >= self .max_errors :
307- self .num_parallel_requests = 0 # for safety reasons
308344 raise EmbeddingConnectionError from ex
309345 return None
310346 except Exception :
0 commit comments