Skip to content

Commit f7cd683

Browse files
committed
Server embedder: use queue, handle unsuccessful requests at the end
1 parent c38b96f commit f7cd683

File tree

2 files changed

+125
-80
lines changed

2 files changed

+125
-80
lines changed

Orange/misc/server_embedder.py

Lines changed: 116 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,30 @@
33
import logging
44
import random
55
import uuid
6+
from collections import namedtuple
67
from json import JSONDecodeError
78
from os import getenv
89
from typing import Any, Callable, List, Optional
910

1011
from AnyQt.QtCore import QSettings
1112
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
1213

13-
from Orange.misc.utils.embedder_utils import (EmbedderCache,
14-
EmbeddingCancelledException,
15-
EmbeddingConnectionError,
16-
get_proxies)
14+
from Orange.misc.utils.embedder_utils import (
15+
EmbedderCache,
16+
EmbeddingCancelledException,
17+
EmbeddingConnectionError,
18+
get_proxies,
19+
)
1720

1821
log = logging.getLogger(__name__)
22+
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
1923

2024

2125
class ServerEmbedderCommunicator:
2226
"""
2327
This class needs to be inherited by the class which re-implements
24-
_encode_data_instance and defines self.content_type. For sending a table
25-
with data items use embedd_table function. This one is called with the
26-
complete Orange data Table. Then _encode_data_instance needs to extract
27-
data to be embedded from the RowInstance. For images, it takes the image
28-
path from the table, load image, and transform it into bytes.
28+
_encode_data_instance and defines self.content_type. For sending a list
29+
with data items use embedd_table function.
2930
3031
Attributes
3132
----------
@@ -69,14 +70,14 @@ def __init__(
6970
) or str(uuid.getnode())
7071
except TypeError:
7172
self.machine_id = str(uuid.getnode())
72-
self.session_id = str(random.randint(1, 1e10))
73+
self.session_id = str(random.randint(1, int(1e10)))
7374

7475
self._cache = EmbedderCache(model_name)
7576

7677
# default embedding timeouts are too small we need to increase them
7778
self.timeout = 180
78-
self.num_parallel_requests = 0
79-
self.max_parallel = max_parallel_requests
79+
self.max_parallel_requests = max_parallel_requests
80+
8081
self.content_type = None # need to be set in a class inheriting
8182

8283
def embedd_data(
@@ -111,8 +112,7 @@ def embedd_data(
111112
EmbeddingCancelledException:
112113
If cancelled attribute is set to True (default=False).
113114
"""
114-
# if there is less items than 10 connection error should be raised
115-
# earlier
115+
# if there are fewer than 10 items, a connection error should be raised earlier
116116
self.max_errors = min(len(data) * self.MAX_REPEATS, 10)
117117

118118
loop = asyncio.new_event_loop()
@@ -121,11 +121,9 @@ def embedd_data(
121121
embeddings = asyncio.get_event_loop().run_until_complete(
122122
self.embedd_batch(data, processed_callback)
123123
)
124-
except Exception:
124+
finally:
125125
loop.close()
126-
raise
127126

128-
loop.close()
129127
return embeddings
130128

131129
async def embedd_batch(
@@ -153,32 +151,63 @@ async def embedd_batch(
153151
EmbeddingCancelledException:
154152
If cancelled attribute is set to True (default=False).
155153
"""
156-
requests = []
154+
results = [None] * len(data)
155+
queue = asyncio.Queue()
156+
157+
# fill the queue with items to embedd
158+
for i, item in enumerate(data):
159+
queue.put_nowait(TaskItem(id=i, item=item, no_repeats=0))
160+
157161
async with AsyncClient(
158-
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
162+
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
159163
) as client:
160-
for p in data:
161-
if self._cancelled:
162-
raise EmbeddingCancelledException()
163-
requests.append(self._send_to_server(p, client, proc_callback))
164+
tasks = self._init_workers(client, queue, results, proc_callback)
165+
166+
# wait for the queue to complete or one of workers to exit
167+
queue_complete = asyncio.create_task(queue.join())
168+
await asyncio.wait(
169+
[queue_complete, *tasks], return_when=asyncio.FIRST_COMPLETED
170+
)
171+
172+
# Cancel worker tasks when done
173+
queue_complete.cancel()
174+
await self._cancel_workers(tasks)
164175

165-
embeddings = await asyncio.gather(*requests)
166176
self._cache.persist_cache()
167-
assert self.num_parallel_requests == 0
177+
return results
168178

169-
return embeddings
179+
def _init_workers(self, client, queue, results, callback):
180+
"""Init required number of workers"""
181+
t = [
182+
asyncio.create_task(self._send_to_server(client, queue, results, callback))
183+
for _ in range(self.max_parallel_requests)
184+
]
185+
log.debug("Created %d workers", self.max_parallel_requests)
186+
return t
170187

171-
async def __wait_until_released(self) -> None:
172-
while self.num_parallel_requests >= self.max_parallel:
173-
await asyncio.sleep(0.1)
188+
@staticmethod
189+
async def _cancel_workers(tasks):
190+
"""Cancel workers at the end."""
191+
log.debug("Canceling workers")
192+
try:
193+
# try to catch any potential exceptions
194+
await asyncio.gather(*tasks)
195+
except Exception as ex:
196+
# raise exceptions gathered from a failed worker
197+
raise ex
198+
finally:
199+
# cancel all tasks in both cases
200+
for task in tasks:
201+
task.cancel()
202+
# Wait until all worker tasks are cancelled.
203+
await asyncio.gather(*tasks, return_exceptions=True)
204+
log.debug("All workers canceled")
174205

175206
def __check_cancelled(self):
176207
if self._cancelled:
177208
raise EmbeddingCancelledException()
178209

179-
async def _encode_data_instance(
180-
self, data_instance: Any
181-
) -> Optional[bytes]:
210+
async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
182211
"""
183212
The reimplementation of this function must implement the procedure
184213
to encode the data item in a string format that will be sent to the
@@ -197,63 +226,74 @@ async def _encode_data_instance(
197226
raise NotImplementedError
198227

199228
async def _send_to_server(
200-
self,
201-
data_instance: Any,
202-
client: AsyncClient,
203-
proc_callback: Callable[[bool], None] = None,
204-
) -> Optional[List[float]]:
229+
self,
230+
client: AsyncClient,
231+
queue: asyncio.Queue,
232+
results: List,
233+
proc_callback: Callable[[bool], None] = None,
234+
):
205235
"""
206-
Function get an data instance. It extract data from it and send them to
207-
server and retrieve responses.
236+
Worker that embeds data. It pulls items from the queue until the queue
237+
is empty. It is canceled by embedd_batch when all tasks are finished.
208238
209239
Parameters
210240
----------
211-
data_instance
212-
Single row of the input table.
213241
client
214242
HTTPX client that communicates with the server
243+
queue
244+
The queue with items of type TaskItem to be embedded
245+
results
246+
The list to append results in. The list has length equal to numbers
247+
of all items to embedd. The result need to be inserted at the index
248+
defined in queue items.
215249
proc_callback
216250
A function that is called after each item is fully processed
217251
by either getting a successful response from the server,
218252
getting the result from cache or skipping the item.
219-
220-
Returns
221-
-------
222-
Embedding. For items that are not successfully embedded returns None.
223253
"""
224-
await self.__wait_until_released()
225-
self.__check_cancelled()
226-
227-
self.num_parallel_requests += 1
228-
# load bytes
229-
data_bytes = await self._encode_data_instance(data_instance)
230-
if data_bytes is None:
231-
self.num_parallel_requests -= 1
232-
return None
233-
234-
# if data in cache return it
235-
cache_key = self._cache.md5_hash(data_bytes)
236-
emb = self._cache.get_cached_result_or_none(cache_key)
237-
238-
if emb is None:
239-
# in case that embedding not sucessfull resend it to the server
240-
# maximally for MAX_REPEATS time
241-
for i in range(1, self.MAX_REPEATS + 1):
242-
self.__check_cancelled()
254+
while not queue.empty():
255+
self.__check_cancelled()
256+
257+
# get item from the queue
258+
i, data_instance, num_repeats = await queue.get()
259+
num_repeats += 1
260+
261+
# load bytes
262+
data_bytes = await self._encode_data_instance(data_instance)
263+
if data_bytes is None:
264+
continue
265+
266+
# retrieve embedded item from the local cache
267+
cache_key = self._cache.md5_hash(data_bytes)
268+
log.debug("Embedding %s", cache_key)
269+
emb = self._cache.get_cached_result_or_none(cache_key)
270+
271+
if emb is None:
272+
# send the item to the server for embedding if not in the local cache
273+
log.debug("Sending to the server: %s", cache_key)
243274
url = (
244-
f"/{self.embedder_type}/{self._model}?"
245-
f"machine={self.machine_id}"
246-
f"&session={self.session_id}&retry={i}"
275+
f"/{self.embedder_type}/{self._model}?machine={self.machine_id}"
276+
f"&session={self.session_id}&retry={num_repeats}"
247277
)
248278
emb = await self._send_request(client, data_bytes, url)
249279
if emb is not None:
250280
self._cache.add(cache_key, emb)
251-
break # repeat only when embedding None
252-
if proc_callback:
253-
proc_callback(emb is not None)
254281

255-
self.num_parallel_requests -= 1
256-
return emb
282+
if emb is not None:
283+
# store result if embedding is successful
284+
log.debug("Successfully embedded: %s", cache_key)
285+
results[i] = emb
286+
if proc_callback:
287+
proc_callback(emb is not None)
288+
elif num_repeats < self.MAX_REPEATS:
289+
log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
290+
# if embedding is not successful, put the item back in the queue to be
291+
# handled at the end - it goes to the back since the server may still be
292+
# processing the request and the result may appear in the cache later;
293+
# repeating the request immediately could fail again while the
294+
# processing takes longer
295+
queue.put_nowait(TaskItem(i, data_instance, no_repeats=num_repeats))
296+
queue.task_done()
257297

258298
async def _send_request(
259299
self, client: AsyncClient, data: bytes, url: str
@@ -284,27 +324,23 @@ async def _send_request(
284324
response = await client.post(url, headers=headers, data=data)
285325
except ReadTimeout as ex:
286326
log.debug("Read timeout", exc_info=True)
287-
# it happens when server do not respond in 60 seconds, in
288-
# this case we return None and items will be resend later
327+
# it happens when the server does not respond within the timeout;
328+
# return None and items will be resend later
289329

290330
# if it happens more than in ten consecutive cases it means
291331
# sth is wrong with embedder we stop embedding
292332
self.count_read_errors += 1
293-
294333
if self.count_read_errors >= self.max_errors:
295-
self.num_parallel_requests = 0 # for safety reasons
296334
raise EmbeddingConnectionError from ex
297335
return None
298336
except (OSError, NetworkError) as ex:
299337
log.debug("Network error", exc_info=True)
300-
# it happens when no connection and items cannot be sent to the
301-
# server
302-
# we count number of consecutive errors
338+
# it happens when no connection and items cannot be sent to server
339+
303340
# if more than 10 consecutive errors it means there is no
304341
# connection so we stop embedding with EmbeddingConnectionError
305342
self.count_connection_errors += 1
306343
if self.count_connection_errors >= self.max_errors:
307-
self.num_parallel_requests = 0 # for safety reasons
308344
raise EmbeddingConnectionError from ex
309345
return None
310346
except Exception:

Orange/misc/tests/test_server_embedder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,12 @@ def test_encode_data_instance(self):
167167
mocked_fun.assert_has_calls(
168168
[call(item) for item in self.test_data], any_order=True
169169
)
170+
171+
@patch(_HTTPX_POST_METHOD, return_value=DummyResponse(b''), new_callable=AsyncMock)
172+
def test_retries(self, mock):
173+
self.embedder.embedd_data(self.test_data)
174+
self.assertEqual(len(self.test_data) * 3, mock.call_count)
175+
176+
177+
if __name__ == "__main__":
178+
unittest.main()

0 commit comments

Comments
 (0)