Skip to content

Commit 3782737

Browse files
committed
Server embedder: use queue, handle unsuccessful requests at the end
1 parent c38b96f commit 3782737

File tree

2 files changed

+118
-77
lines changed

2 files changed

+118
-77
lines changed

Orange/misc/server_embedder.py

Lines changed: 114 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,30 @@
33
import logging
44
import random
55
import uuid
6+
from collections import namedtuple
67
from json import JSONDecodeError
78
from os import getenv
89
from typing import Any, Callable, List, Optional
910

1011
from AnyQt.QtCore import QSettings
1112
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
1213

13-
from Orange.misc.utils.embedder_utils import (EmbedderCache,
14-
EmbeddingCancelledException,
15-
EmbeddingConnectionError,
16-
get_proxies)
14+
from Orange.misc.utils.embedder_utils import (
15+
EmbedderCache,
16+
EmbeddingCancelledException,
17+
EmbeddingConnectionError,
18+
get_proxies,
19+
)
1720

1821
log = logging.getLogger(__name__)
22+
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
1923

2024

2125
class ServerEmbedderCommunicator:
2226
"""
2327
This class needs to be inherited by the class which re-implements
24-
_encode_data_instance and defines self.content_type. For sending a table
25-
with data items use embedd_table function. This one is called with the
26-
complete Orange data Table. Then _encode_data_instance needs to extract
27-
data to be embedded from the RowInstance. For images, it takes the image
28-
path from the table, load image, and transform it into bytes.
28+
_encode_data_instance and defines self.content_type. For sending a list
29+
with data items use embedd_table function.
2930
3031
Attributes
3132
----------
@@ -69,14 +70,14 @@ def __init__(
6970
) or str(uuid.getnode())
7071
except TypeError:
7172
self.machine_id = str(uuid.getnode())
72-
self.session_id = str(random.randint(1, 1e10))
73+
self.session_id = str(random.randint(1, int(1e10)))
7374

7475
self._cache = EmbedderCache(model_name)
7576

7677
# default embedding timeouts are too small we need to increase them
7778
self.timeout = 180
78-
self.num_parallel_requests = 0
79-
self.max_parallel = max_parallel_requests
79+
self.max_parallel_requests = max_parallel_requests
80+
8081
self.content_type = None # need to be set in a class inheriting
8182

8283
def embedd_data(
@@ -111,8 +112,7 @@ def embedd_data(
111112
EmbeddingCancelledException:
112113
If cancelled attribute is set to True (default=False).
113114
"""
114-
# if there is less items than 10 connection error should be raised
115-
# earlier
115+
# if there are fewer than 10 items, the connection error should be raised earlier
116116
self.max_errors = min(len(data) * self.MAX_REPEATS, 10)
117117

118118
loop = asyncio.new_event_loop()
@@ -122,10 +122,10 @@ def embedd_data(
122122
self.embedd_batch(data, processed_callback)
123123
)
124124
except Exception:
125-
loop.close()
126125
raise
126+
finally:
127+
loop.close()
127128

128-
loop.close()
129129
return embeddings
130130

131131
async def embedd_batch(
@@ -153,32 +153,63 @@ async def embedd_batch(
153153
EmbeddingCancelledException:
154154
If cancelled attribute is set to True (default=False).
155155
"""
156-
requests = []
156+
results = [None] * len(data)
157+
queue = asyncio.Queue()
158+
159+
# fill the queue with items to embedd
160+
for i, item in enumerate(data):
161+
queue.put_nowait(TaskItem(id=i, item=item, no_repeats=0))
162+
157163
async with AsyncClient(
158-
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
164+
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
159165
) as client:
160-
for p in data:
161-
if self._cancelled:
162-
raise EmbeddingCancelledException()
163-
requests.append(self._send_to_server(p, client, proc_callback))
166+
tasks = self._init_workers(client, queue, results, proc_callback)
167+
168+
# wait for the queue to complete or one of workers to exit
169+
queue_complete = asyncio.create_task(queue.join())
170+
await asyncio.wait(
171+
[queue_complete, *tasks], return_when=asyncio.FIRST_COMPLETED
172+
)
173+
174+
# Cancel worker tasks when done
175+
queue_complete.cancel()
176+
await self._cancel_workers(tasks)
164177

165-
embeddings = await asyncio.gather(*requests)
166178
self._cache.persist_cache()
167-
assert self.num_parallel_requests == 0
179+
return results
168180

169-
return embeddings
181+
def _init_workers(self, client, queue, results, callback):
182+
"""Init required number of workers"""
183+
t = [
184+
asyncio.create_task(self._send_to_server(client, queue, results, callback))
185+
for _ in range(self.max_parallel_requests)
186+
]
187+
log.debug(f"Created {self.max_parallel_requests} workers")
188+
return t
170189

171-
async def __wait_until_released(self) -> None:
172-
while self.num_parallel_requests >= self.max_parallel:
173-
await asyncio.sleep(0.1)
190+
@staticmethod
191+
async def _cancel_workers(tasks):
192+
"""Cancel worker at the end"""
193+
log.debug(f"Canceling workers")
194+
try:
195+
# try to catch any potential exceptions
196+
await asyncio.gather(*tasks)
197+
except Exception as ex:
198+
# raise exceptions gathered from a failed worker
199+
raise ex
200+
finally:
201+
# cancel all tasks in both cases
202+
for task in tasks:
203+
task.cancel()
204+
# Wait until all worker tasks are cancelled.
205+
await asyncio.gather(*tasks, return_exceptions=True)
206+
log.debug(f"All workers canceled")
174207

175208
def __check_cancelled(self):
176209
if self._cancelled:
177210
raise EmbeddingCancelledException()
178211

179-
async def _encode_data_instance(
180-
self, data_instance: Any
181-
) -> Optional[bytes]:
212+
async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
182213
"""
183214
The reimplementation of this function must implement the procedure
184215
to encode the data item in a string format that will be sent to the
@@ -197,63 +228,73 @@ async def _encode_data_instance(
197228
raise NotImplementedError
198229

199230
async def _send_to_server(
200-
self,
201-
data_instance: Any,
202-
client: AsyncClient,
203-
proc_callback: Callable[[bool], None] = None,
204-
) -> Optional[List[float]]:
231+
self,
232+
client: AsyncClient,
233+
queue: asyncio.Queue,
234+
results: List,
235+
proc_callback: Callable[[bool], None] = None,
236+
):
205237
"""
206-
Function get an data instance. It extract data from it and send them to
207-
server and retrieve responses.
238+
Worker that embeds data. It pulls items from the queue until the queue
239+
is empty. It is canceled by embedd_batch once all tasks are finished.
208240
209241
Parameters
210242
----------
211-
data_instance
212-
Single row of the input table.
213243
client
214244
HTTPX client that communicates with the server
245+
queue
246+
The queue with items of type TaskItem to be embedded
247+
results
248+
The list to store results in. Its length equals the number
249+
of all items to embed. Each result needs to be inserted at the index
250+
defined in the corresponding queue item.
215251
proc_callback
216252
A function that is called after each item is fully processed
217253
by either getting a successful response from the server,
218254
getting the result from cache or skipping the item.
219-
220-
Returns
221-
-------
222-
Embedding. For items that are not successfully embedded returns None.
223255
"""
224-
await self.__wait_until_released()
225-
self.__check_cancelled()
226-
227-
self.num_parallel_requests += 1
228-
# load bytes
229-
data_bytes = await self._encode_data_instance(data_instance)
230-
if data_bytes is None:
231-
self.num_parallel_requests -= 1
232-
return None
256+
while not queue.empty():
257+
self.__check_cancelled()
258+
259+
# get item from the queue
260+
i, data_instance, num_repeats = await queue.get()
261+
262+
# load bytes
263+
data_bytes = await self._encode_data_instance(data_instance)
264+
if data_bytes is None:
265+
continue
233266

234-
# if data in cache return it
235-
cache_key = self._cache.md5_hash(data_bytes)
236-
emb = self._cache.get_cached_result_or_none(cache_key)
267+
# retrieve embedded item from the local cache
268+
cache_key = self._cache.md5_hash(data_bytes)
269+
log.debug(f"Embedding {cache_key}")
270+
emb = self._cache.get_cached_result_or_none(cache_key)
237271

238-
if emb is None:
239-
# in case that embedding not sucessfull resend it to the server
240-
# maximally for MAX_REPEATS time
241-
for i in range(1, self.MAX_REPEATS + 1):
242-
self.__check_cancelled()
272+
if emb is None:
273+
# send the item to the server for embedding if not in the local cache
274+
log.debug(f"Sending to the server: {cache_key}")
243275
url = (
244-
f"/{self.embedder_type}/{self._model}?"
245-
f"machine={self.machine_id}"
246-
f"&session={self.session_id}&retry={i}"
276+
f"/{self.embedder_type}/{self._model}?machine={self.machine_id}"
277+
f"&session={self.session_id}&retry={num_repeats+1}"
247278
)
248279
emb = await self._send_request(client, data_bytes, url)
249280
if emb is not None:
250281
self._cache.add(cache_key, emb)
251-
break # repeat only when embedding None
252-
if proc_callback:
253-
proc_callback(emb is not None)
254282

255-
self.num_parallel_requests -= 1
256-
return emb
283+
if emb is not None:
284+
# store result if embedding is successful
285+
log.debug(f"Successfully embedded: {cache_key}")
286+
results[i] = emb
287+
if proc_callback:
288+
proc_callback(emb is not None)
289+
elif num_repeats < self.MAX_REPEATS:
290+
log.debug(f"Not embedded successfully - reading to queue: {cache_key}")
291+
# if embedding is not successful, put the item back to the queue to be handled at
292+
# the end - the item is put at the end since it is possible that the server
293+
# is still processing the request and the result will be in the cache later;
294+
# repeating the request immediately may result in another failure when
295+
# processing takes longer
296+
queue.put_nowait(TaskItem(i, data_instance, no_repeats=num_repeats + 1))
297+
queue.task_done()
257298

258299
async def _send_request(
259300
self, client: AsyncClient, data: bytes, url: str
@@ -284,27 +325,23 @@ async def _send_request(
284325
response = await client.post(url, headers=headers, data=data)
285326
except ReadTimeout as ex:
286327
log.debug("Read timeout", exc_info=True)
287-
# it happens when server do not respond in 60 seconds, in
288-
# this case we return None and items will be resend later
328+
# it happens when the server does not respond within the time defined by timeout;
329+
# return None and the items will be resent later
289330

290331
# if it happens more than in ten consecutive cases it means
291332
# sth is wrong with embedder we stop embedding
292333
self.count_read_errors += 1
293-
294334
if self.count_read_errors >= self.max_errors:
295-
self.num_parallel_requests = 0 # for safety reasons
296335
raise EmbeddingConnectionError from ex
297336
return None
298337
except (OSError, NetworkError) as ex:
299338
log.debug("Network error", exc_info=True)
300-
# it happens when no connection and items cannot be sent to the
301-
# server
302-
# we count number of consecutive errors
339+
# it happens when there is no connection and items cannot be sent to the server
340+
303341
# if more than 10 consecutive errors it means there is no
304342
# connection so we stop embedding with EmbeddingConnectionError
305343
self.count_connection_errors += 1
306344
if self.count_connection_errors >= self.max_errors:
307-
self.num_parallel_requests = 0 # for safety reasons
308345
raise EmbeddingConnectionError from ex
309346
return None
310347
except Exception:

Orange/misc/tests/test_server_embedder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,7 @@ def test_encode_data_instance(self):
167167
mocked_fun.assert_has_calls(
168168
[call(item) for item in self.test_data], any_order=True
169169
)
170+
171+
172+
if __name__ == "__main__":
173+
unittest.main()

0 commit comments

Comments
 (0)