33import logging
44import random
55import uuid
6+ import warnings
67from collections import namedtuple
8+ from functools import partial
79from json import JSONDecodeError
810from os import getenv
911from typing import Any , Callable , List , Optional
1012
1113from AnyQt .QtCore import QSettings
1214from httpx import AsyncClient , NetworkError , ReadTimeout , Response
15+ from numpy import linspace
1316
1417from Orange .misc .utils .embedder_utils import (
1518 EmbedderCache ,
1619 EmbeddingCancelledException ,
1720 EmbeddingConnectionError ,
1821 get_proxies ,
1922)
23+ from Orange .util import dummy_callback
2024
2125log = logging .getLogger (__name__ )
2226TaskItem = namedtuple ("TaskItem" , ("id" , "item" , "no_repeats" ))
@@ -59,8 +63,7 @@ def __init__(
5963 self ._model = model_name
6064 self .embedder_type = embedder_type
6165
62- # attribute that offers support for cancelling the embedding
63- # if ran in another thread
66+ # remove in 3.33
6467 self ._cancelled = False
6568
6669 self .machine_id = None
@@ -81,9 +84,11 @@ def __init__(
8184 self .content_type = None # need to be set in a class inheriting
8285
8386 def embedd_data (
84- self ,
85- data : List [Any ],
86- processed_callback : Callable [[bool ], None ] = None ,
87+ self ,
88+ data : List [Any ],
89+ processed_callback : Optional [Callable ] = None ,
90+ * ,
91+ callback : Callable = dummy_callback ,
8792 ) -> List [Optional [List [float ]]]:
8893 """
8994 This function repeats calling embedding function until all items
@@ -95,9 +100,12 @@ def embedd_data(
95100 data
96101 List with data that needs to be embedded.
97102 processed_callback
103+ Deprecated: remove in 3.33
98104 A function that is called after each item is embedded
99105 by either getting a successful response from the server,
100106 getting the result from cache or skipping the item.
107+ callback
108+ Callback for reporting the progress in share of embedded items
101109
102110 Returns
103111 -------
@@ -119,15 +127,19 @@ def embedd_data(
119127 asyncio .set_event_loop (loop )
120128 try :
121129 embeddings = asyncio .get_event_loop ().run_until_complete (
122- self .embedd_batch (data , processed_callback )
130+ self .embedd_batch (data , processed_callback , callback = callback )
123131 )
124132 finally :
125133 loop .close ()
126134
127135 return embeddings
128136
129137 async def embedd_batch (
130- self , data : List [Any ], proc_callback : Callable [[bool ], None ] = None
138+ self ,
139+ data : List [Any ],
140+ proc_callback : Optional [Callable ] = None ,
141+ * ,
142+ callback : Callable = dummy_callback ,
131143 ) -> List [Optional [List [float ]]]:
132144 """
133145 Function perform embedding of a batch of data items.
@@ -136,10 +148,8 @@ async def embedd_batch(
136148 ----------
137149 data
138150 A list of data that must be embedded.
139- proc_callback
140- A function that is called after each item is fully processed
141- by either getting a successful response from the server,
142- getting the result from cache or skipping the item.
151+ callback
152+ Callback for reporting the progress in share of embedded items
143153
144154 Returns
145155 -------
@@ -151,6 +161,21 @@ async def embedd_batch(
151161 EmbeddingCancelledException:
152162 If cancelled attribute is set to True (default=False).
153163 """
164+ # in Orange 3.33 keep content of the if - remove if clause and complete else
165+ if proc_callback is None :
166+ progress_items = iter (linspace (0 , 1 , len (data )))
167+
168+ def success_callback ():
169+ """Callback called on every successful embedding"""
170+ callback (next (progress_items ))
171+ else :
172+ warnings .warn (
173+ "proc_callback is deprecated and will be removed in version 3.33, "
174+ "use callback instead" ,
175+ FutureWarning ,
176+ )
177+ success_callback = partial (proc_callback , True )
178+
154179 results = [None ] * len (data )
155180 queue = asyncio .Queue ()
156181
@@ -161,7 +186,7 @@ async def embedd_batch(
161186 async with AsyncClient (
162187 timeout = self .timeout , base_url = self .server_url , proxies = get_proxies ()
163188 ) as client :
164- tasks = self ._init_workers (client , queue , results , proc_callback )
189+ tasks = self ._init_workers (client , queue , results , success_callback )
165190
166191 # wait for the queue to complete or one of workers to exit
167192 queue_complete = asyncio .create_task (queue .join ())
@@ -203,6 +228,7 @@ async def _cancel_workers(tasks):
203228 await asyncio .gather (* tasks , return_exceptions = True )
204229 log .debug ("All workers canceled" )
205230
231+ # remove in 3.33
206232 def __check_cancelled (self ):
207233 if self ._cancelled :
208234 raise EmbeddingCancelledException ()
@@ -230,11 +256,11 @@ async def _send_to_server(
230256 client : AsyncClient ,
231257 queue : asyncio .Queue ,
232258 results : List ,
233- proc_callback : Callable [[ bool ], None ] = None ,
259+ proc_callback : Callable ,
234260 ):
235261 """
236- Worker that embedds data. It is pulling items from the until the queue
237- is empty. It is canceled by embedd_batch all tasks are finished
262+ Worker that embedds data. It is pulling items from the queue until
263+ it is empty. It is runs until anything is in the queue, or it is canceled
238264
239265 Parameters
240266 ----------
@@ -252,6 +278,7 @@ async def _send_to_server(
252278 getting the result from cache or skipping the item.
253279 """
254280 while not queue .empty ():
281+ # remove in 3.33
255282 self .__check_cancelled ()
256283
257284 # get item from the queue
@@ -283,8 +310,7 @@ async def _send_to_server(
283310 # store result if embedding is successful
284311 log .debug ("Successfully embedded: %s" , cache_key )
285312 results [i ] = emb
286- if proc_callback :
287- proc_callback (emb is not None )
313+ proc_callback ()
288314 elif num_repeats < self .MAX_REPEATS :
289315 log .debug ("Embedding unsuccessful - reading to queue: %s" , cache_key )
290316 # if embedding not successful put the item to queue to be handled at
@@ -379,5 +405,11 @@ def _parse_response(response: Response) -> Optional[List[float]]:
379405 def clear_cache (self ):
380406 self ._cache .clear_cache ()
381407
408+ # remove in 3.33
382409 def set_cancelled (self ):
410+ warnings .warn (
411+ "set_cancelled is deprecated and will be removed in version 3.33, "
412+ "the process can be canceled by raising Error in callback" ,
413+ FutureWarning ,
414+ )
383415 self ._cancelled = True
0 commit comments