Skip to content

Commit b8928c2

Browse files
authored
Merge pull request #5835 from PrimozGodec/embedders-change-order
[ENH] Server embedder: use queue, handle unsuccessful requests at the end
2 parents 05fc0df + f3a9dde commit b8928c2

File tree

2 files changed

+193
-92
lines changed

2 files changed

+193
-92
lines changed

Orange/misc/server_embedder.py

Lines changed: 159 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,34 @@
33
import logging
44
import random
55
import uuid
6+
import warnings
7+
from collections import namedtuple
8+
from functools import partial
69
from json import JSONDecodeError
710
from os import getenv
811
from typing import Any, Callable, List, Optional
912

1013
from AnyQt.QtCore import QSettings
1114
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
15+
from numpy import linspace
1216

13-
from Orange.misc.utils.embedder_utils import (EmbedderCache,
14-
EmbeddingCancelledException,
15-
EmbeddingConnectionError,
16-
get_proxies)
17+
from Orange.misc.utils.embedder_utils import (
18+
EmbedderCache,
19+
EmbeddingCancelledException,
20+
EmbeddingConnectionError,
21+
get_proxies,
22+
)
23+
from Orange.util import dummy_callback
1724

1825
log = logging.getLogger(__name__)
26+
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
1927

2028

2129
class ServerEmbedderCommunicator:
2230
"""
2331
This class needs to be inherited by the class which re-implements
24-
_encode_data_instance and defines self.content_type. For sending a table
25-
with data items use embedd_table function. This one is called with the
26-
complete Orange data Table. Then _encode_data_instance needs to extract
27-
data to be embedded from the RowInstance. For images, it takes the image
28-
path from the table, load image, and transform it into bytes.
32+
_encode_data_instance and defines self.content_type. For sending a list
33+
with data items use embedd_table function.
2934
3035
Attributes
3136
----------
@@ -58,8 +63,7 @@ def __init__(
5863
self._model = model_name
5964
self.embedder_type = embedder_type
6065

61-
# attribute that offers support for cancelling the embedding
62-
# if ran in another thread
66+
# remove in 3.33
6367
self._cancelled = False
6468

6569
self.machine_id = None
@@ -69,20 +73,22 @@ def __init__(
6973
) or str(uuid.getnode())
7074
except TypeError:
7175
self.machine_id = str(uuid.getnode())
72-
self.session_id = str(random.randint(1, 1e10))
76+
self.session_id = str(random.randint(1, int(1e10)))
7377

7478
self._cache = EmbedderCache(model_name)
7579

7680
# default embedding timeouts are too small, so we need to increase them
7781
self.timeout = 180
78-
self.num_parallel_requests = 0
79-
self.max_parallel = max_parallel_requests
82+
self.max_parallel_requests = max_parallel_requests
83+
8084
self.content_type = None # need to be set in a class inheriting
8185

8286
def embedd_data(
83-
self,
84-
data: List[Any],
85-
processed_callback: Callable[[bool], None] = None,
87+
self,
88+
data: List[Any],
89+
processed_callback: Optional[Callable] = None,
90+
*,
91+
callback: Callable = dummy_callback,
8692
) -> List[Optional[List[float]]]:
8793
"""
8894
This function repeats calling embedding function until all items
@@ -94,9 +100,12 @@ def embedd_data(
94100
data
95101
List with data that needs to be embedded.
96102
processed_callback
103+
Deprecated: remove in 3.33
97104
A function that is called after each item is embedded
98105
by either getting a successful response from the server,
99106
getting the result from cache or skipping the item.
107+
callback
108+
Callback for reporting the progress in share of embedded items
100109
101110
Returns
102111
-------
@@ -111,25 +120,26 @@ def embedd_data(
111120
EmbeddingCancelledException:
112121
If cancelled attribute is set to True (default=False).
113122
"""
114-
# if there is less items than 10 connection error should be raised
115-
# earlier
123+
# if there are fewer than 10 items, the connection error should be raised earlier
116124
self.max_errors = min(len(data) * self.MAX_REPEATS, 10)
117125

118126
loop = asyncio.new_event_loop()
119127
asyncio.set_event_loop(loop)
120128
try:
121129
embeddings = asyncio.get_event_loop().run_until_complete(
122-
self.embedd_batch(data, processed_callback)
130+
self.embedd_batch(data, processed_callback, callback=callback)
123131
)
124-
except Exception:
132+
finally:
125133
loop.close()
126-
raise
127134

128-
loop.close()
129135
return embeddings
130136

131137
async def embedd_batch(
132-
self, data: List[Any], proc_callback: Callable[[bool], None] = None
138+
self,
139+
data: List[Any],
140+
proc_callback: Optional[Callable] = None,
141+
*,
142+
callback: Callable = dummy_callback,
133143
) -> List[Optional[List[float]]]:
134144
"""
135145
Function perform embedding of a batch of data items.
@@ -138,10 +148,8 @@ async def embedd_batch(
138148
----------
139149
data
140150
A list of data that must be embedded.
141-
proc_callback
142-
A function that is called after each item is fully processed
143-
by either getting a successful response from the server,
144-
getting the result from cache or skipping the item.
151+
callback
152+
Callback for reporting the progress in share of embedded items
145153
146154
Returns
147155
-------
@@ -153,32 +161,79 @@ async def embedd_batch(
153161
EmbeddingCancelledException:
154162
If cancelled attribute is set to True (default=False).
155163
"""
156-
requests = []
164+
# in Orange 3.33 keep content of the if - remove if clause and complete else
165+
if proc_callback is None:
166+
progress_items = iter(linspace(0, 1, len(data)))
167+
168+
def success_callback():
169+
"""Callback called on every successful embedding"""
170+
callback(next(progress_items))
171+
else:
172+
warnings.warn(
173+
"proc_callback is deprecated and will be removed in version 3.33, "
174+
"use callback instead",
175+
FutureWarning,
176+
)
177+
success_callback = partial(proc_callback, True)
178+
179+
results = [None] * len(data)
180+
queue = asyncio.Queue()
181+
182+
# fill the queue with items to embedd
183+
for i, item in enumerate(data):
184+
queue.put_nowait(TaskItem(id=i, item=item, no_repeats=0))
185+
157186
async with AsyncClient(
158-
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
187+
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
159188
) as client:
160-
for p in data:
161-
if self._cancelled:
162-
raise EmbeddingCancelledException()
163-
requests.append(self._send_to_server(p, client, proc_callback))
189+
tasks = self._init_workers(client, queue, results, success_callback)
164190

165-
embeddings = await asyncio.gather(*requests)
166-
self._cache.persist_cache()
167-
assert self.num_parallel_requests == 0
191+
# wait for the queue to complete or one of workers to exit
192+
queue_complete = asyncio.create_task(queue.join())
193+
await asyncio.wait(
194+
[queue_complete, *tasks], return_when=asyncio.FIRST_COMPLETED
195+
)
168196

169-
return embeddings
197+
# Cancel worker tasks when done
198+
queue_complete.cancel()
199+
await self._cancel_workers(tasks)
170200

171-
async def __wait_until_released(self) -> None:
172-
while self.num_parallel_requests >= self.max_parallel:
173-
await asyncio.sleep(0.1)
201+
self._cache.persist_cache()
202+
return results
203+
204+
def _init_workers(self, client, queue, results, callback):
    """Spawn the configured number of worker tasks.

    Each worker runs ``_send_to_server`` and pulls items from *queue*,
    writing embeddings into *results* and reporting progress via *callback*.
    Returns the list of created asyncio tasks.
    """
    workers = []
    for _ in range(self.max_parallel_requests):
        worker_coro = self._send_to_server(client, queue, results, callback)
        workers.append(asyncio.create_task(worker_coro))
    log.debug("Created %d workers", self.max_parallel_requests)
    return workers
174212

213+
@staticmethod
async def _cancel_workers(tasks):
    """Cancel worker tasks at the end of a batch.

    Awaits the workers once so that any exception raised inside a worker
    propagates to the caller, then — success or failure — cancels every
    task and waits for the cancellations to be acknowledged.
    """
    log.debug("Canceling workers")
    try:
        # surface any exception gathered from a failed worker
        # (the original wrapped this in `except Exception as ex: raise ex`,
        # a no-op re-raise; try/finally alone expresses the intent)
        await asyncio.gather(*tasks)
    finally:
        # cancel all tasks in both cases
        for task in tasks:
            task.cancel()
        # Wait until all worker tasks are cancelled.
        await asyncio.gather(*tasks, return_exceptions=True)
        log.debug("All workers canceled")
230+
231+
# remove in 3.33
def __check_cancelled(self):
    """Raise ``EmbeddingCancelledException`` if cancellation was requested."""
    if not self._cancelled:
        return
    raise EmbeddingCancelledException()
178235

179-
async def _encode_data_instance(
180-
self, data_instance: Any
181-
) -> Optional[bytes]:
236+
async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
182237
"""
183238
The reimplementation of this function must implement the procedure
184239
to encode the data item in a string format that will be sent to the
@@ -197,63 +252,73 @@ async def _encode_data_instance(
197252
raise NotImplementedError
198253

199254
async def _send_to_server(
    self,
    client: AsyncClient,
    queue: asyncio.Queue,
    results: List,
    proc_callback: Callable,
):
    """
    Worker that embeds data. It pulls items from the queue until the
    queue is empty; it runs as long as anything is in the queue, or
    until it is cancelled.

    Parameters
    ----------
    client
        HTTPX client that communicates with the server
    queue
        The queue with items of type TaskItem to be embedded
    results
        The list to insert results in. The list has length equal to the
        number of all items to embed. The result is inserted at the index
        defined in the queue item.
    proc_callback
        A function that is called after each item is fully processed
        by either getting a successful response from the server,
        getting the result from cache or skipping the item.
    """
    while not queue.empty():
        # remove in 3.33
        self.__check_cancelled()

        # get item from the queue
        i, data_instance, num_repeats = await queue.get()
        try:
            # load bytes
            data_bytes = await self._encode_data_instance(data_instance)
            if data_bytes is None:
                # item cannot be encoded - skip it (result stays None)
                continue

            # retrieve embedded item from the local cache
            cache_key = self._cache.md5_hash(data_bytes)
            log.debug("Embedding %s", cache_key)
            emb = self._cache.get_cached_result_or_none(cache_key)

            if emb is None:
                # send the item to the server for embedding if not in the local cache
                log.debug("Sending to the server: %s", cache_key)
                url = (
                    f"/{self.embedder_type}/{self._model}?machine={self.machine_id}"
                    f"&session={self.session_id}&retry={num_repeats+1}"
                )
                emb = await self._send_request(client, data_bytes, url)
                if emb is not None:
                    self._cache.add(cache_key, emb)

            if emb is not None:
                # store result if embedding is successful
                log.debug("Successfully embedded: %s", cache_key)
                results[i] = emb
                proc_callback()
            elif num_repeats + 1 < self.MAX_REPEATS:
                log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
                # if embedding not successful put the item to queue to be handled at
                # the end - the item is put to the end since it is possible that server
                # still process the request and the result will be in the cache later
                # repeating the request immediately may result in another fail when
                # processing takes longer
                queue.put_nowait(
                    TaskItem(i, data_instance, no_repeats=num_repeats + 1)
                )
        finally:
            # every successful get() must be matched by a task_done() so that
            # queue.join() can complete; the original skipped task_done() on
            # the `continue` path (unencodable item), leaving the join counter
            # permanently unbalanced
            queue.task_done()
257322

258323
async def _send_request(
259324
self, client: AsyncClient, data: bytes, url: str
@@ -284,27 +349,23 @@ async def _send_request(
284349
response = await client.post(url, headers=headers, data=data)
285350
except ReadTimeout as ex:
286351
log.debug("Read timeout", exc_info=True)
287-
# it happens when server do not respond in 60 seconds, in
288-
# this case we return None and items will be resend later
352+
# it happens when server do not respond in time defined by timeout
353+
# return None and items will be resend later
289354

290355
# if it happens more than in ten consecutive cases it means
291356
# sth is wrong with embedder we stop embedding
292357
self.count_read_errors += 1
293-
294358
if self.count_read_errors >= self.max_errors:
295-
self.num_parallel_requests = 0 # for safety reasons
296359
raise EmbeddingConnectionError from ex
297360
return None
298361
except (OSError, NetworkError) as ex:
299362
log.debug("Network error", exc_info=True)
300-
# it happens when no connection and items cannot be sent to the
301-
# server
302-
# we count number of consecutive errors
363+
# it happens when no connection and items cannot be sent to server
364+
303365
# if more than 10 consecutive errors it means there is no
304366
# connection so we stop embedding with EmbeddingConnectionError
305367
self.count_connection_errors += 1
306368
if self.count_connection_errors >= self.max_errors:
307-
self.num_parallel_requests = 0 # for safety reasons
308369
raise EmbeddingConnectionError from ex
309370
return None
310371
except Exception:
@@ -343,5 +404,11 @@ def _parse_response(response: Response) -> Optional[List[float]]:
343404
def clear_cache(self):
    """Drop all locally cached embeddings for this embedder's model."""
    self._cache.clear_cache()
345406

407+
# remove in 3.33
346408
def set_cancelled(self):
409+
warnings.warn(
410+
"set_cancelled is deprecated and will be removed in version 3.33, "
411+
"the process can be canceled by raising Error in callback",
412+
FutureWarning,
413+
)
347414
self._cancelled = True

0 commit comments

Comments
 (0)