Skip to content

Commit 9f5322a

Browse files
committed
server_embedder: modify callback to match others
1 parent 2d49cf8 commit 9f5322a

File tree

2 files changed

+74
-17
lines changed

2 files changed

+74
-17
lines changed

Orange/misc/server_embedder.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,24 @@
33
import logging
44
import random
55
import uuid
6+
import warnings
67
from collections import namedtuple
8+
from functools import partial
79
from json import JSONDecodeError
810
from os import getenv
911
from typing import Any, Callable, List, Optional
1012

1113
from AnyQt.QtCore import QSettings
1214
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
15+
from numpy import linspace
1316

1417
from Orange.misc.utils.embedder_utils import (
1518
EmbedderCache,
1619
EmbeddingCancelledException,
1720
EmbeddingConnectionError,
1821
get_proxies,
1922
)
23+
from Orange.util import dummy_callback
2024

2125
log = logging.getLogger(__name__)
2226
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
@@ -59,8 +63,7 @@ def __init__(
5963
self._model = model_name
6064
self.embedder_type = embedder_type
6165

62-
# attribute that offers support for cancelling the embedding
63-
# if ran in another thread
66+
# remove in 3.33
6467
self._cancelled = False
6568

6669
self.machine_id = None
@@ -81,9 +84,11 @@ def __init__(
8184
self.content_type = None # need to be set in a class inheriting
8285

8386
def embedd_data(
84-
self,
85-
data: List[Any],
86-
processed_callback: Callable[[bool], None] = None,
87+
self,
88+
data: List[Any],
89+
processed_callback: Optional[Callable] = None,
90+
*,
91+
callback: Callable = dummy_callback,
8792
) -> List[Optional[List[float]]]:
8893
"""
8994
This function repeats calling embedding function until all items
@@ -95,9 +100,12 @@ def embedd_data(
95100
data
96101
List with data that needs to be embedded.
97102
processed_callback
103+
Deprecated: remove in 3.33
98104
A function that is called after each item is embedded
99105
by either getting a successful response from the server,
100106
getting the result from cache or skipping the item.
107+
callback
108+
Callback for reporting the progress in share of embedded items
101109
102110
Returns
103111
-------
@@ -119,15 +127,19 @@ def embedd_data(
119127
asyncio.set_event_loop(loop)
120128
try:
121129
embeddings = asyncio.get_event_loop().run_until_complete(
122-
self.embedd_batch(data, processed_callback)
130+
self.embedd_batch(data, processed_callback, callback=callback)
123131
)
124132
finally:
125133
loop.close()
126134

127135
return embeddings
128136

129137
async def embedd_batch(
130-
self, data: List[Any], proc_callback: Callable[[bool], None] = None
138+
self,
139+
data: List[Any],
140+
proc_callback: Optional[Callable] = None,
141+
*,
142+
callback: Callable = dummy_callback,
131143
) -> List[Optional[List[float]]]:
132144
"""
133145
Function perform embedding of a batch of data items.
@@ -136,10 +148,8 @@ async def embedd_batch(
136148
----------
137149
data
138150
A list of data that must be embedded.
139-
proc_callback
140-
A function that is called after each item is fully processed
141-
by either getting a successful response from the server,
142-
getting the result from cache or skipping the item.
151+
callback
152+
Callback for reporting the progress in share of embedded items
143153
144154
Returns
145155
-------
@@ -151,6 +161,21 @@ async def embedd_batch(
151161
EmbeddingCancelledException:
152162
If cancelled attribute is set to True (default=False).
153163
"""
164+
# in Orange 3.33 keep content of the if - remove if clause and complete else
165+
if proc_callback is None:
166+
progress_items = iter(linspace(0, 1, len(data)))
167+
168+
def success_callback():
169+
"""Callback called on every successful embedding"""
170+
callback(next(progress_items))
171+
else:
172+
warnings.warn(
173+
"proc_callback is deprecated and will be removed in version 3.33, "
174+
"use callback instead",
175+
FutureWarning,
176+
)
177+
success_callback = partial(proc_callback, True)
178+
154179
results = [None] * len(data)
155180
queue = asyncio.Queue()
156181

@@ -161,7 +186,7 @@ async def embedd_batch(
161186
async with AsyncClient(
162187
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
163188
) as client:
164-
tasks = self._init_workers(client, queue, results, proc_callback)
189+
tasks = self._init_workers(client, queue, results, success_callback)
165190

166191
# wait for the queue to complete or one of workers to exit
167192
queue_complete = asyncio.create_task(queue.join())
@@ -203,6 +228,7 @@ async def _cancel_workers(tasks):
203228
await asyncio.gather(*tasks, return_exceptions=True)
204229
log.debug("All workers canceled")
205230

231+
# remove in 3.33
206232
def __check_cancelled(self):
207233
if self._cancelled:
208234
raise EmbeddingCancelledException()
@@ -230,11 +256,11 @@ async def _send_to_server(
230256
client: AsyncClient,
231257
queue: asyncio.Queue,
232258
results: List,
233-
proc_callback: Callable[[bool], None] = None,
259+
proc_callback: Callable,
234260
):
235261
"""
236-
Worker that embedds data. It is pulling items from the until the queue
237-
is empty. It is canceled by embedd_batch all tasks are finished
262+
Worker that embedds data. It is pulling items from the queue until
263+
it is empty. It is runs until anything is in the queue, or it is canceled
238264
239265
Parameters
240266
----------
@@ -252,6 +278,7 @@ async def _send_to_server(
252278
getting the result from cache or skipping the item.
253279
"""
254280
while not queue.empty():
281+
# remove in 3.33
255282
self.__check_cancelled()
256283

257284
# get item from the queue
@@ -283,8 +310,7 @@ async def _send_to_server(
283310
# store result if embedding is successful
284311
log.debug("Successfully embedded: %s", cache_key)
285312
results[i] = emb
286-
if proc_callback:
287-
proc_callback(emb is not None)
313+
proc_callback()
288314
elif num_repeats < self.MAX_REPEATS:
289315
log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
290316
# if embedding not successful put the item to queue to be handled at
@@ -379,5 +405,11 @@ def _parse_response(response: Response) -> Optional[List[float]]:
379405
def clear_cache(self):
380406
self._cache.clear_cache()
381407

408+
# remove in 3.33
382409
def set_cancelled(self):
410+
warnings.warn(
411+
"set_cancelled is deprecated and will be removed in version 3.33, "
412+
"the process can be canceled by raising Error in callback",
413+
FutureWarning,
414+
)
383415
self._cancelled = True

Orange/misc/tests/test_server_embedder.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
from httpx import ReadTimeout
77

8+
import Orange
89
from Orange.data import Domain, StringVariable, Table
910
from Orange.misc.tests.example_embedder import ExampleServerEmbedder
1011

@@ -173,6 +174,30 @@ def test_retries(self, mock):
173174
self.embedder.embedd_data(self.test_data)
174175
self.assertEqual(len(self.test_data) * 3, mock.call_count)
175176

177+
@patch(_HTTPX_POST_METHOD, regular_dummy_sr)
178+
def test_callback(self):
179+
mock = MagicMock()
180+
self.embedder.embedd_data(self.test_data, callback=mock)
181+
182+
process_items = [call(x) for x in np.linspace(0, 1, len(self.test_data))]
183+
mock.assert_has_calls(process_items)
184+
185+
@patch(_HTTPX_POST_METHOD, regular_dummy_sr)
186+
def test_deprecated(self):
187+
"""
188+
When this start to fail:
189+
- remove process_callback parameter and marked places connected to this param
190+
- remove set_canceled and marked places connected to this method
191+
- this test
192+
"""
193+
self.assertGreaterEqual("3.33.0", Orange.__version__)
194+
195+
mock = MagicMock()
196+
self.embedder.embedd_data(self.test_data, processed_callback=mock)
197+
198+
process_items = [call(True) for _ in range(len(self.test_data))]
199+
mock.assert_has_calls(process_items)
200+
176201

177202
if __name__ == "__main__":
178203
unittest.main()

0 commit comments

Comments
 (0)