33import logging
44import random
55import uuid
6+ import warnings
7+ from collections import namedtuple
8+ from functools import partial
69from json import JSONDecodeError
710from os import getenv
811from typing import Any , Callable , List , Optional
912
1013from AnyQt .QtCore import QSettings
1114from httpx import AsyncClient , NetworkError , ReadTimeout , Response
15+ from numpy import linspace
1216
13- from Orange .misc .utils .embedder_utils import (EmbedderCache ,
14- EmbeddingCancelledException ,
15- EmbeddingConnectionError ,
16- get_proxies )
17+ from Orange .misc .utils .embedder_utils import (
18+ EmbedderCache ,
19+ EmbeddingCancelledException ,
20+ EmbeddingConnectionError ,
21+ get_proxies ,
22+ )
23+ from Orange .util import dummy_callback
1724
1825log = logging .getLogger (__name__ )
26+ TaskItem = namedtuple ("TaskItem" , ("id" , "item" , "no_repeats" ))
1927
2028
2129class ServerEmbedderCommunicator :
2230 """
2331 This class needs to be inherited by the class which re-implements
24- _encode_data_instance and defines self.content_type. For sending a table
25- with data items use embedd_table function. This one is called with the
26- complete Orange data Table. Then _encode_data_instance needs to extract
27- data to be embedded from the RowInstance. For images, it takes the image
28- path from the table, load image, and transform it into bytes.
32+ _encode_data_instance and defines self.content_type. For sending a list
33+ with data items use embedd_table function.
2934
3035 Attributes
3136 ----------
@@ -58,8 +63,7 @@ def __init__(
5863 self ._model = model_name
5964 self .embedder_type = embedder_type
6065
61- # attribute that offers support for cancelling the embedding
62- # if ran in another thread
66+ # remove in 3.33
6367 self ._cancelled = False
6468
6569 self .machine_id = None
@@ -69,20 +73,22 @@ def __init__(
6973 ) or str (uuid .getnode ())
7074 except TypeError :
7175 self .machine_id = str (uuid .getnode ())
72- self .session_id = str (random .randint (1 , 1e10 ))
76+ self .session_id = str (random .randint (1 , int ( 1e10 ) ))
7377
7478 self ._cache = EmbedderCache (model_name )
7579
7680 # default embedding timeouts are too small we need to increase them
7781 self .timeout = 180
78- self .num_parallel_requests = 0
79- self . max_parallel = max_parallel_requests
82+ self .max_parallel_requests = max_parallel_requests
83+
8084 self .content_type = None # need to be set in a class inheriting
8185
def embedd_data(
    self,
    data: List[Any],
    processed_callback: Optional[Callable] = None,
    *,
    callback: Callable = dummy_callback,
) -> List[Optional[List[float]]]:
    """
    Embed all items in ``data`` and block until the batch is finished.

    The work is done by the ``embedd_batch`` coroutine, which is run to
    completion on an event loop created for this call. The loop is closed
    before returning, whether the coroutine succeeded or raised.

    Parameters
    ----------
    data
        List with data that needs to be embedded.
    processed_callback
        Deprecated: remove in 3.33
        Legacy progress hook; passed on unchanged to ``embedd_batch``.
    callback
        Callback for reporting the progress in share of embedded items

    Returns
    -------
    The list produced by ``embedd_batch``.

    Raises
    ------
    EmbeddingCancelledException:
        If cancelled attribute is set to True (default=False).
    """
    # with fewer than 10 items the connection error should be raised earlier
    self.max_errors = min(len(data) * self.MAX_REPEATS, 10)

    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    try:
        # returning from inside ``try`` still runs ``finally`` first,
        # so the loop is closed on success and on failure alike
        return asyncio.get_event_loop().run_until_complete(
            self.embedd_batch(data, processed_callback, callback=callback)
        )
    finally:
        event_loop.close()
130136
async def embedd_batch(
    self,
    data: List[Any],
    proc_callback: Optional[Callable] = None,
    *,
    callback: Callable = dummy_callback,
) -> List[Optional[List[float]]]:
    """
    Embed a batch of data items with a pool of concurrent workers.

    Every item is wrapped in a ``TaskItem`` and placed on a queue. Workers
    created by ``_init_workers`` consume the queue and write each embedding
    into the result list at the position of the item in ``data``.

    Parameters
    ----------
    data
        A list of data that must be embedded.
    proc_callback
        Deprecated: remove in 3.33
        When given, a ``FutureWarning`` is issued and the function is
        called as ``proc_callback(True)`` for every successful embedding;
        ``callback`` is then not used.
    callback
        Callback for reporting the progress in share of embedded items.
        The reported values are taken, in order, from
        ``linspace(0, 1, len(data))``.

    Returns
    -------
    List with one entry per item in ``data``. Entries that no worker
    filled in remain None.

    Raises
    ------
    EmbeddingCancelledException:
        If cancelled attribute is set to True (default=False).
    """
    # in Orange 3.33 drop the deprecated branch and keep only the ``else`` part
    if proc_callback is not None:
        warnings.warn(
            "proc_callback is deprecated and will be removed in version 3.33, "
            "use callback instead",
            FutureWarning,
        )
        success_callback = partial(proc_callback, True)
    else:
        progress_items = iter(linspace(0, 1, len(data)))

        def success_callback():
            """Callback called on every successful embedding"""
            callback(next(progress_items))

    results = [None] * len(data)

    # every item starts with zero repeats; workers re-queue failed items
    queue = asyncio.Queue()
    for index, item in enumerate(data):
        queue.put_nowait(TaskItem(id=index, item=item, no_repeats=0))

    async with AsyncClient(
        timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
    ) as client:
        workers = self._init_workers(client, queue, results, success_callback)

        # resume as soon as the queue is fully processed or any worker exits
        queue_complete = asyncio.create_task(queue.join())
        watched = [queue_complete, *workers]
        await asyncio.wait(watched, return_when=asyncio.FIRST_COMPLETED)

        # the join task is no longer needed; workers are awaited and cleaned up
        queue_complete.cancel()
        await self._cancel_workers(workers)

    self._cache.persist_cache()
    return results
203+
def _init_workers(self, client, queue, results, callback):
    """
    Start the worker tasks that consume the queue.

    ``self.max_parallel_requests`` tasks are created, each running
    ``_send_to_server`` with the same client, queue, result list and
    callback. The list of created tasks is returned.
    """
    n_workers = self.max_parallel_requests
    workers = []
    for _ in range(n_workers):
        worker = self._send_to_server(client, queue, results, callback)
        workers.append(asyncio.create_task(worker))
    log.debug("Created %d workers", n_workers)
    return workers
174212
@staticmethod
async def _cancel_workers(tasks):
    """
    Await the worker tasks, then cancel whatever is still left.

    An exception raised by a worker propagates to the caller, but only
    after all tasks have been cancelled and awaited.
    """
    log.debug("Canceling workers")
    try:
        # the first exception raised by any worker surfaces here
        await asyncio.gather(*tasks)
    finally:
        # runs on success and on failure: cancel every task ...
        for worker in tasks:
            worker.cancel()
        # ... and wait until all of them have actually finished
        await asyncio.gather(*tasks, return_exceptions=True)
        log.debug("All workers canceled")
230+
# remove in 3.33
def __check_cancelled(self):
    """Raise EmbeddingCancelledException when ``self._cancelled`` is set."""
    if not self._cancelled:
        return
    raise EmbeddingCancelledException()
178235
179- async def _encode_data_instance (
180- self , data_instance : Any
181- ) -> Optional [bytes ]:
236+ async def _encode_data_instance (self , data_instance : Any ) -> Optional [bytes ]:
182237 """
183238 The reimplementation of this function must implement the procedure
184239 to encode the data item in a string format that will be sent to the
@@ -197,63 +252,73 @@ async def _encode_data_instance(
197252 raise NotImplementedError
198253
async def _send_to_server(
    self,
    client: AsyncClient,
    queue: asyncio.Queue,
    results: List,
    proc_callback: Callable,
):
    """
    Worker that embeds data. It pulls items from the queue and processes
    them one by one; it runs for as long as the queue is not empty or
    until it is cancelled.

    Parameters
    ----------
    client
        HTTPX client that communicates with the server
    queue
        The queue with items of type TaskItem to be embedded
    results
        The list to store results in. The embedding of an item is written
        at the index given by the queue item. Positions of items that are
        skipped or never embedded successfully are left untouched.
    proc_callback
        A function called without arguments after each item is
        successfully embedded, either by the server or from the cache.
        It is not called for skipped or failed items.
    """
    while not queue.empty():
        # remove in 3.33
        self.__check_cancelled()

        # get item from the queue
        i, data_instance, num_repeats = await queue.get()

        # load bytes
        data_bytes = await self._encode_data_instance(data_instance)
        if data_bytes is None:
            # nothing to embed for this item; it still has to be marked as
            # done, otherwise the queue's unfinished-task counter never
            # reaches zero and queue.join() cannot complete
            queue.task_done()
            continue

        # retrieve embedded item from the local cache
        cache_key = self._cache.md5_hash(data_bytes)
        log.debug("Embedding %s", cache_key)
        emb = self._cache.get_cached_result_or_none(cache_key)

        if emb is None:
            # send the item to the server for embedding if not in the local cache
            log.debug("Sending to the server: %s", cache_key)
            url = (
                f"/{self.embedder_type}/{self._model}?machine={self.machine_id}"
                f"&session={self.session_id}&retry={num_repeats + 1}"
            )
            emb = await self._send_request(client, data_bytes, url)
            if emb is not None:
                self._cache.add(cache_key, emb)

        if emb is not None:
            # store result if embedding is successful
            log.debug("Successfully embedded: %s", cache_key)
            results[i] = emb
            proc_callback()
        elif num_repeats + 1 < self.MAX_REPEATS:
            log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
            # a failed item goes back to the END of the queue: the server may
            # still be processing the request, so the result can be in its
            # cache later, while an immediate retry could fail again when
            # processing takes longer
            queue.put_nowait(TaskItem(i, data_instance, no_repeats=num_repeats + 1))
        # the re-queued copy (if any) was added before this call, so the
        # unfinished-task counter does not drop to zero prematurely
        queue.task_done()
257322
258323 async def _send_request (
259324 self , client : AsyncClient , data : bytes , url : str
@@ -284,27 +349,23 @@ async def _send_request(
284349 response = await client .post (url , headers = headers , data = data )
285350 except ReadTimeout as ex :
286351 log .debug ("Read timeout" , exc_info = True )
287- # it happens when server do not respond in 60 seconds, in
288- # this case we return None and items will be resend later
352+ # it happens when server do not respond in time defined by timeout
353+ # return None and items will be resend later
289354
290355 # if it happens more than in ten consecutive cases it means
291356 # sth is wrong with embedder we stop embedding
292357 self .count_read_errors += 1
293-
294358 if self .count_read_errors >= self .max_errors :
295- self .num_parallel_requests = 0 # for safety reasons
296359 raise EmbeddingConnectionError from ex
297360 return None
298361 except (OSError , NetworkError ) as ex :
299362 log .debug ("Network error" , exc_info = True )
300- # it happens when no connection and items cannot be sent to the
301- # server
302- # we count number of consecutive errors
363+ # it happens when no connection and items cannot be sent to server
364+
303365 # if more than 10 consecutive errors it means there is no
304366 # connection so we stop embedding with EmbeddingConnectionError
305367 self .count_connection_errors += 1
306368 if self .count_connection_errors >= self .max_errors :
307- self .num_parallel_requests = 0 # for safety reasons
308369 raise EmbeddingConnectionError from ex
309370 return None
310371 except Exception :
@@ -343,5 +404,11 @@ def _parse_response(response: Response) -> Optional[List[float]]:
def clear_cache(self):
    """Clear the embedder cache by delegating to ``self._cache.clear_cache()``."""
    cache = self._cache
    cache.clear_cache()
345406
# remove in 3.33
def set_cancelled(self):
    """
    Deprecated way of cancelling a running embedding.

    Issues a ``FutureWarning`` and sets ``self._cancelled`` to True; the
    workers check this flag and raise EmbeddingCancelledException.
    """
    warnings.warn(
        "set_cancelled is deprecated and will be removed in version 3.33, "
        "the process can be canceled by raising Error in callback",
        FutureWarning,
        # attribute the warning to the code calling set_cancelled,
        # not to this line inside the library
        stacklevel=2,
    )
    self._cancelled = True
0 commit comments