Skip to content

Commit a9bafc3

Browse files
committed
server_embedder: modify callback to match others
1 parent 04547a5 commit a9bafc3

File tree

2 files changed

+62
-13
lines changed

2 files changed

+62
-13
lines changed

Orange/misc/server_embedder.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,23 @@
33
import logging
44
import random
55
import uuid
6+
import warnings
67
from collections import namedtuple
78
from json import JSONDecodeError
89
from os import getenv
910
from typing import Any, Callable, List, Optional
1011

1112
from AnyQt.QtCore import QSettings
1213
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
14+
from numpy import linspace
1315

1416
from Orange.misc.utils.embedder_utils import (
1517
EmbedderCache,
1618
EmbeddingCancelledException,
1719
EmbeddingConnectionError,
1820
get_proxies,
1921
)
22+
from Orange.util import dummy_callback
2023

2124
log = logging.getLogger(__name__)
2225
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
@@ -59,8 +62,7 @@ def __init__(
5962
self._model = model_name
6063
self.embedder_type = embedder_type
6164

62-
# attribute that offers support for cancelling the embedding
63-
# if ran in another thread
65+
# remove in 3.33
6466
self._cancelled = False
6567

6668
self.machine_id = None
@@ -81,9 +83,10 @@ def __init__(
8183
self.content_type = None # need to be set in a class inheriting
8284

8385
def embedd_data(
84-
self,
85-
data: List[Any],
86-
processed_callback: Callable[[bool], None] = None,
86+
self,
87+
data: List[Any],
88+
processed_callback: Optional[Callable] = None,
89+
callback: Callable = dummy_callback,
8790
) -> List[Optional[List[float]]]:
8891
"""
8992
This function repeats calling embedding function until all items
@@ -95,9 +98,12 @@ def embedd_data(
9598
data
9699
List with data that needs to be embedded.
97100
processed_callback
101+
Deprecated: remove in 3.33
98102
A function that is called after each item is embedded
99103
by either getting a successful response from the server,
100104
getting the result from cache or skipping the item.
105+
callback
106+
Callback for reporting the progress in share of embedded items
101107
102108
Returns
103109
-------
@@ -119,15 +125,18 @@ def embedd_data(
119125
asyncio.set_event_loop(loop)
120126
try:
121127
embeddings = asyncio.get_event_loop().run_until_complete(
122-
self.embedd_batch(data, processed_callback)
128+
self.embedd_batch(data, processed_callback, callback)
123129
)
124130
finally:
125131
loop.close()
126132

127133
return embeddings
128134

129135
async def embedd_batch(
130-
self, data: List[Any], proc_callback: Callable[[bool], None] = None
136+
self,
137+
data: List[Any],
138+
processed_calback: Optional[Callable] = None,
139+
callback: Callable = dummy_callback,
131140
) -> List[Optional[List[float]]]:
132141
"""
133142
Function perform embedding of a batch of data items.
@@ -136,10 +145,8 @@ async def embedd_batch(
136145
----------
137146
data
138147
A list of data that must be embedded.
139-
proc_callback
140-
A function that is called after each item is fully processed
141-
by either getting a successful response from the server,
142-
getting the result from cache or skipping the item.
148+
callback
149+
Callback for reporting the progress in share of embedded items
143150
144151
Returns
145152
-------
@@ -151,6 +158,22 @@ async def embedd_batch(
151158
EmbeddingCancelledException:
152159
If cancelled attribute is set to True (default=False).
153160
"""
161+
# in Orange 3.33 keep content of the if - remove if clause and complete else
162+
if processed_calback is None:
163+
progress_items = iter(linspace(0, 1, len(data)))
164+
165+
def success_callback():
166+
"""Callback called on every successful embedding"""
167+
callback(next(progress_items))
168+
169+
else:
170+
warnings.warn(
171+
"process_callback is deprecated and will be removed in version 3.33, "
172+
"use callback instead",
173+
FutureWarning,
174+
)
175+
success_callback = processed_calback
176+
154177
results = [None] * len(data)
155178
queue = asyncio.Queue()
156179

@@ -161,7 +184,7 @@ async def embedd_batch(
161184
async with AsyncClient(
162185
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
163186
) as client:
164-
tasks = self._init_workers(client, queue, results, proc_callback)
187+
tasks = self._init_workers(client, queue, results, success_callback)
165188

166189
# wait for the queue to complete or one of workers to exit
167190
queue_complete = asyncio.create_task(queue.join())
@@ -203,6 +226,7 @@ async def _cancel_workers(tasks):
203226
await asyncio.gather(*tasks, return_exceptions=True)
204227
log.debug("All workers canceled")
205228

229+
# remove in 3.33
206230
def __check_cancelled(self):
207231
if self._cancelled:
208232
raise EmbeddingCancelledException()
@@ -252,6 +276,7 @@ async def _send_to_server(
252276
getting the result from cache or skipping the item.
253277
"""
254278
while not queue.empty():
279+
# remove in 3.33
255280
self.__check_cancelled()
256281

257282
# get item from the queue
@@ -284,7 +309,7 @@ async def _send_to_server(
284309
log.debug("Successfully embedded: %s", cache_key)
285310
results[i] = emb
286311
if proc_callback:
287-
proc_callback(emb is not None)
312+
proc_callback()
288313
elif num_repeats < self.MAX_REPEATS:
289314
log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
290315
# if embedding not successful put the item to queue to be handled at
@@ -379,5 +404,11 @@ def _parse_response(response: Response) -> Optional[List[float]]:
379404
def clear_cache(self):
380405
self._cache.clear_cache()
381406

407+
# remove in 3.33
382408
def set_cancelled(self):
409+
warnings.warn(
410+
"set_cancelled is deprecated and will be removed in version 3.33, "
411+
"the process can be canceled by raising Error in callback",
412+
FutureWarning,
413+
)
383414
self._cancelled = True

Orange/misc/tests/test_server_embedder.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
from httpx import ReadTimeout
77

8+
import Orange
89
from Orange.data import Domain, StringVariable, Table
910
from Orange.misc.tests.example_embedder import ExampleServerEmbedder
1011

@@ -173,6 +174,23 @@ def test_retries(self, mock):
173174
self.embedder.embedd_data(self.test_data)
174175
self.assertEqual(len(self.test_data) * 3, mock.call_count)
175176

177+
@patch(_HTTPX_POST_METHOD, regular_dummy_sr)
178+
def test_callback(self):
179+
mock = MagicMock()
180+
self.embedder.embedd_data(self.test_data, callback=mock)
181+
182+
process_items = [call(x) for x in np.linspace(0, 1, len(self.test_data))]
183+
mock.assert_has_calls(process_items)
184+
185+
def test_deprecated(self):
186+
"""
187+
When this start to fail:
188+
- remove process_callback parameter and marked places connected to this param
189+
- remove set_canceled and marked places connected to this method
190+
- this test
191+
"""
192+
self.assertGreaterEqual("3.33.0", Orange.__version__)
193+
176194

177195
if __name__ == "__main__":
178196
unittest.main()

0 commit comments

Comments
 (0)