Skip to content

Commit a4f8146

Browse files
committed
Server embedder: use queue, handle unsuccessful requests at the end
1 parent c38b96f commit a4f8146

File tree

2 files changed

+126
-79
lines changed

2 files changed

+126
-79
lines changed

Orange/misc/server_embedder.py

Lines changed: 117 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,30 @@
33
import logging
44
import random
55
import uuid
6+
from collections import namedtuple
67
from json import JSONDecodeError
78
from os import getenv
89
from typing import Any, Callable, List, Optional
910

1011
from AnyQt.QtCore import QSettings
1112
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
1213

13-
from Orange.misc.utils.embedder_utils import (EmbedderCache,
14-
EmbeddingCancelledException,
15-
EmbeddingConnectionError,
16-
get_proxies)
14+
from Orange.misc.utils.embedder_utils import (
15+
EmbedderCache,
16+
EmbeddingCancelledException,
17+
EmbeddingConnectionError,
18+
get_proxies,
19+
)
1720

1821
log = logging.getLogger(__name__)
22+
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
1923

2024

2125
class ServerEmbedderCommunicator:
2226
"""
2327
This class needs to be inherited by the class which re-implements
24-
_encode_data_instance and defines self.content_type. For sending a table
25-
with data items use embedd_table function. This one is called with the
26-
complete Orange data Table. Then _encode_data_instance needs to extract
27-
data to be embedded from the RowInstance. For images, it takes the image
28-
path from the table, load image, and transform it into bytes.
28+
_encode_data_instance and defines self.content_type. For sending a list
29+
with data items use embedd_table function.
2930
3031
Attributes
3132
----------
@@ -69,14 +70,14 @@ def __init__(
6970
) or str(uuid.getnode())
7071
except TypeError:
7172
self.machine_id = str(uuid.getnode())
72-
self.session_id = str(random.randint(1, 1e10))
73+
self.session_id = str(random.randint(1, int(1e10)))
7374

7475
self._cache = EmbedderCache(model_name)
7576

7677
# default embedding timeouts are too small we need to increase them
7778
self.timeout = 180
78-
self.num_parallel_requests = 0
79-
self.max_parallel = max_parallel_requests
79+
self.max_parallel_requests = max_parallel_requests
80+
8081
self.content_type = None # need to be set in a class inheriting
8182

8283
def embedd_data(
@@ -111,8 +112,7 @@ def embedd_data(
111112
EmbeddingCancelledException:
112113
If cancelled attribute is set to True (default=False).
113114
"""
114-
# if there is less items than 10 connection error should be raised
115-
# earlier
115+
# if there are fewer than 10 items, the connection error should be raised earlier
116116
self.max_errors = min(len(data) * self.MAX_REPEATS, 10)
117117

118118
loop = asyncio.new_event_loop()
@@ -122,10 +122,10 @@ def embedd_data(
122122
self.embedd_batch(data, processed_callback)
123123
)
124124
except Exception:
125-
loop.close()
126125
raise
126+
finally:
127+
loop.close()
127128

128-
loop.close()
129129
return embeddings
130130

131131
async def embedd_batch(
@@ -153,32 +153,63 @@ async def embedd_batch(
153153
EmbeddingCancelledException:
154154
If cancelled attribute is set to True (default=False).
155155
"""
156-
requests = []
156+
results = [None] * len(data)
157+
queue = asyncio.Queue()
158+
159+
# fill the queue with items to embedd
160+
for i, item in enumerate(data):
161+
queue.put_nowait(TaskItem(id=i, item=item, no_repeats=0))
162+
157163
async with AsyncClient(
158-
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
164+
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
159165
) as client:
160-
for p in data:
161-
if self._cancelled:
162-
raise EmbeddingCancelledException()
163-
requests.append(self._send_to_server(p, client, proc_callback))
166+
tasks = self._init_workers(client, queue, results, proc_callback)
167+
168+
# wait for the queue to complete or one of workers to exit
169+
queue_complete = asyncio.create_task(queue.join())
170+
await asyncio.wait(
171+
[queue_complete, *tasks], return_when=asyncio.FIRST_COMPLETED
172+
)
173+
174+
# Cancel worker tasks when done
175+
queue_complete.cancel()
176+
await self._cancel_workers(tasks)
164177

165-
embeddings = await asyncio.gather(*requests)
166178
self._cache.persist_cache()
167-
assert self.num_parallel_requests == 0
179+
return results
168180

169-
return embeddings
181+
def _init_workers(self, client, queue, results, callback):
182+
"""Init required number of workers"""
183+
t = [
184+
asyncio.create_task(self._send_to_server(client, queue, results, callback))
185+
for _ in range(self.max_parallel_requests)
186+
]
187+
log.debug(f"Created {self.max_parallel_requests} workers")
188+
return t
170189

171-
async def __wait_until_released(self) -> None:
172-
while self.num_parallel_requests >= self.max_parallel:
173-
await asyncio.sleep(0.1)
190+
@staticmethod
191+
async def _cancel_workers(tasks):
192+
"""Cancel workers at the end"""
193+
log.debug(f"Canceling workers")
194+
try:
195+
# try to catch any potential exceptions
196+
await asyncio.gather(*tasks)
197+
except Exception as ex:
198+
# raise exceptions gathered from a failed worker
199+
raise ex
200+
finally:
201+
# cancel all tasks in both cases
202+
for task in tasks:
203+
task.cancel()
204+
# Wait until all worker tasks are cancelled.
205+
await asyncio.gather(*tasks, return_exceptions=True)
206+
log.debug(f"All workers canceled")
174207

175208
def __check_cancelled(self):
176209
if self._cancelled:
177210
raise EmbeddingCancelledException()
178211

179-
async def _encode_data_instance(
180-
self, data_instance: Any
181-
) -> Optional[bytes]:
212+
async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
182213
"""
183214
The reimplementation of this function must implement the procedure
184215
to encode the data item in a string format that will be sent to the
@@ -197,63 +228,74 @@ async def _encode_data_instance(
197228
raise NotImplementedError
198229

199230
async def _send_to_server(
200-
self,
201-
data_instance: Any,
202-
client: AsyncClient,
203-
proc_callback: Callable[[bool], None] = None,
204-
) -> Optional[List[float]]:
231+
self,
232+
client: AsyncClient,
233+
queue: asyncio.Queue,
234+
results: List,
235+
proc_callback: Callable[[bool], None] = None,
236+
):
205237
"""
206-
Function get an data instance. It extract data from it and send them to
207-
server and retrieve responses.
238+
Worker that embeds data. It pulls items from the queue until the queue
239+
is empty. It is canceled by embedd_batch when all tasks are finished
208240
209241
Parameters
210242
----------
211-
data_instance
212-
Single row of the input table.
213243
client
214244
HTTPX client that communicates with the server
245+
queue
246+
The queue with items of type TaskItem to be embedded
247+
results
248+
The list to append results in. The list has length equal to numbers
249+
of all items to embedd. The result need to be inserted at the index
250+
defined in queue items.
215251
proc_callback
216252
A function that is called after each item is fully processed
217253
by either getting a successful response from the server,
218254
getting the result from cache or skipping the item.
219-
220-
Returns
221-
-------
222-
Embedding. For items that are not successfully embedded returns None.
223255
"""
224-
await self.__wait_until_released()
225-
self.__check_cancelled()
226-
227-
self.num_parallel_requests += 1
228-
# load bytes
229-
data_bytes = await self._encode_data_instance(data_instance)
230-
if data_bytes is None:
231-
self.num_parallel_requests -= 1
232-
return None
233-
234-
# if data in cache return it
235-
cache_key = self._cache.md5_hash(data_bytes)
236-
emb = self._cache.get_cached_result_or_none(cache_key)
237-
238-
if emb is None:
239-
# in case that embedding not sucessfull resend it to the server
240-
# maximally for MAX_REPEATS time
241-
for i in range(1, self.MAX_REPEATS + 1):
242-
self.__check_cancelled()
256+
while not queue.empty():
257+
self.__check_cancelled()
258+
259+
# get item from the queue
260+
i, data_instance, num_repeats = await queue.get()
261+
num_repeats += 1
262+
263+
# load bytes
264+
data_bytes = await self._encode_data_instance(data_instance)
265+
if data_bytes is None:
266+
continue
267+
268+
# retrieve embedded item from the local cache
269+
cache_key = self._cache.md5_hash(data_bytes)
270+
log.debug(f"Embedding {cache_key}")
271+
emb = self._cache.get_cached_result_or_none(cache_key)
272+
273+
if emb is None:
274+
# send the item to the server for embedding if not in the local cache
275+
log.debug(f"Sending to the server: {cache_key}")
243276
url = (
244-
f"/{self.embedder_type}/{self._model}?"
245-
f"machine={self.machine_id}"
246-
f"&session={self.session_id}&retry={i}"
277+
f"/{self.embedder_type}/{self._model}?machine={self.machine_id}"
278+
f"&session={self.session_id}&retry={num_repeats}"
247279
)
248280
emb = await self._send_request(client, data_bytes, url)
249281
if emb is not None:
250282
self._cache.add(cache_key, emb)
251-
break # repeat only when embedding None
252-
if proc_callback:
253-
proc_callback(emb is not None)
254283

255-
self.num_parallel_requests -= 1
256-
return emb
284+
if emb is not None:
285+
# store result if embedding is successful
286+
log.debug(f"Successfully embedded: {cache_key}")
287+
results[i] = emb
288+
if proc_callback:
289+
proc_callback(emb is not None)
290+
elif num_repeats < self.MAX_REPEATS:
291+
log.debug(f"Not embedded successfully - reading to queue: {cache_key}")
292+
# if embedding was not successful, put the item back on the queue to be
293+
# handled at the end - it goes to the end since the server may still be
294+
# processing the request and the result may appear in the cache later;
295+
# repeating the request immediately may result in another failure when
296+
# processing takes longer
297+
queue.put_nowait(TaskItem(i, data_instance, no_repeats=num_repeats))
298+
queue.task_done()
257299

258300
async def _send_request(
259301
self, client: AsyncClient, data: bytes, url: str
@@ -284,27 +326,23 @@ async def _send_request(
284326
response = await client.post(url, headers=headers, data=data)
285327
except ReadTimeout as ex:
286328
log.debug("Read timeout", exc_info=True)
287-
# it happens when server do not respond in 60 seconds, in
288-
# this case we return None and items will be resend later
329+
# it happens when the server does not respond within the timeout
330+
# return None and items will be resend later
289331

290332
# if it happens more than in ten consecutive cases it means
291333
# sth is wrong with embedder we stop embedding
292334
self.count_read_errors += 1
293-
294335
if self.count_read_errors >= self.max_errors:
295-
self.num_parallel_requests = 0 # for safety reasons
296336
raise EmbeddingConnectionError from ex
297337
return None
298338
except (OSError, NetworkError) as ex:
299339
log.debug("Network error", exc_info=True)
300-
# it happens when no connection and items cannot be sent to the
301-
# server
302-
# we count number of consecutive errors
340+
# it happens when there is no connection and items cannot be sent to the server
341+
303342
# if more than 10 consecutive errors it means there is no
304343
# connection so we stop embedding with EmbeddingConnectionError
305344
self.count_connection_errors += 1
306345
if self.count_connection_errors >= self.max_errors:
307-
self.num_parallel_requests = 0 # for safety reasons
308346
raise EmbeddingConnectionError from ex
309347
return None
310348
except Exception:

Orange/misc/tests/test_server_embedder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,12 @@ def test_encode_data_instance(self):
167167
mocked_fun.assert_has_calls(
168168
[call(item) for item in self.test_data], any_order=True
169169
)
170+
171+
@patch(_HTTPX_POST_METHOD, return_value=DummyResponse(b''))
172+
def test_retries(self, mock):
173+
self.embedder.embedd_data(self.test_data)
174+
self.assertEqual(len(self.test_data) * 3, mock.call_count)
175+
176+
177+
if __name__ == "__main__":
178+
unittest.main()

0 commit comments

Comments
 (0)