33import logging
44import random
55import uuid
6+ import warnings
67from collections import namedtuple
78from json import JSONDecodeError
89from os import getenv
910from typing import Any , Callable , List , Optional
1011
1112from AnyQt .QtCore import QSettings
1213from httpx import AsyncClient , NetworkError , ReadTimeout , Response
14+ from numpy import linspace
1315
1416from Orange .misc .utils .embedder_utils import (
1517 EmbedderCache ,
1618 EmbeddingCancelledException ,
1719 EmbeddingConnectionError ,
1820 get_proxies ,
1921)
22+ from Orange .util import dummy_callback
2023
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Record describing one queued unit of embedding work. Fields, in order:
# ``id``, ``item`` and ``no_repeats`` (presumably the number of attempts
# already made for the item -- confirm against the worker code).
TaskItem = namedtuple("TaskItem", ["id", "item", "no_repeats"])
@@ -59,8 +62,7 @@ def __init__(
5962 self ._model = model_name
6063 self .embedder_type = embedder_type
6164
62- # attribute that offers support for cancelling the embedding
63- # if ran in another thread
65+ # remove in 3.33
6466 self ._cancelled = False
6567
6668 self .machine_id = None
@@ -81,9 +83,10 @@ def __init__(
8183 self .content_type = None # need to be set in a class inheriting
8284
8385 def embedd_data (
84- self ,
85- data : List [Any ],
86- processed_callback : Callable [[bool ], None ] = None ,
86+ self ,
87+ data : List [Any ],
88+ processed_callback : Optional [Callable ] = None ,
89+ callback : Callable = dummy_callback ,
8790 ) -> List [Optional [List [float ]]]:
8891 """
8992 This function repeats calling embedding function until all items
@@ -95,9 +98,12 @@ def embedd_data(
9598 data
9699 List with data that needs to be embedded.
97100 processed_callback
101+ Deprecated: remove in 3.33
98102 A function that is called after each item is embedded
99103 by either getting a successful response from the server,
100104 getting the result from cache or skipping the item.
105+ callback
106+ Callback for reporting the progress in share of embedded items
101107
102108 Returns
103109 -------
@@ -119,15 +125,18 @@ def embedd_data(
119125 asyncio .set_event_loop (loop )
120126 try :
121127 embeddings = asyncio .get_event_loop ().run_until_complete (
122- self .embedd_batch (data , processed_callback )
128+ self .embedd_batch (data , processed_callback , callback )
123129 )
124130 finally :
125131 loop .close ()
126132
127133 return embeddings
128134
129135 async def embedd_batch (
130- self , data : List [Any ], proc_callback : Callable [[bool ], None ] = None
136+ self ,
137+ data : List [Any ],
138+ processed_calback : Optional [Callable ] = None ,
139+ callback : Callable = dummy_callback ,
131140 ) -> List [Optional [List [float ]]]:
132141 """
133142 Perform embedding of a batch of data items.
@@ -136,10 +145,8 @@ async def embedd_batch(
136145 ----------
137146 data
138147 A list of data that must be embedded.
139- proc_callback
140- A function that is called after each item is fully processed
141- by either getting a successful response from the server,
142- getting the result from cache or skipping the item.
148+ callback
149+ Callback for reporting the progress in share of embedded items
143150
144151 Returns
145152 -------
@@ -151,6 +158,22 @@ async def embedd_batch(
151158 EmbeddingCancelledException:
152159 If cancelled attribute is set to True (default=False).
153160 """
161+ # In Orange 3.33: keep the body of the `if` branch unconditionally and remove the `if` test together with the whole `else` branch
162+ if processed_calback is None :
163+ progress_items = iter (linspace (0 , 1 , len (data )))
164+
165+ def success_callback ():
166+ """Callback called on every successful embedding"""
167+ callback (next (progress_items ))
168+
169+ else :
170+ warnings .warn (
171+ "process_callback is deprecated and will be removed in version 3.33, "
172+ "use callback instead" ,
173+ FutureWarning ,
174+ )
175+ success_callback = processed_calback
176+
154177 results = [None ] * len (data )
155178 queue = asyncio .Queue ()
156179
@@ -161,7 +184,7 @@ async def embedd_batch(
161184 async with AsyncClient (
162185 timeout = self .timeout , base_url = self .server_url , proxies = get_proxies ()
163186 ) as client :
164- tasks = self ._init_workers (client , queue , results , proc_callback )
187+ tasks = self ._init_workers (client , queue , results , success_callback )
165188
166189 # wait for the queue to complete or one of workers to exit
167190 queue_complete = asyncio .create_task (queue .join ())
@@ -203,6 +226,7 @@ async def _cancel_workers(tasks):
203226 await asyncio .gather (* tasks , return_exceptions = True )
204227 log .debug ("All workers canceled" )
205228
229+ # remove in 3.33
206230 def __check_cancelled (self ):
207231 if self ._cancelled :
208232 raise EmbeddingCancelledException ()
@@ -252,6 +276,7 @@ async def _send_to_server(
252276 getting the result from cache or skipping the item.
253277 """
254278 while not queue .empty ():
279+ # remove in 3.33
255280 self .__check_cancelled ()
256281
257282 # get item from the queue
@@ -284,7 +309,7 @@ async def _send_to_server(
284309 log .debug ("Successfully embedded: %s" , cache_key )
285310 results [i ] = emb
286311 if proc_callback :
287- proc_callback (emb is not None )
312+ proc_callback ()
288313 elif num_repeats < self .MAX_REPEATS :
289314 log .debug ("Embedding unsuccessful - reading to queue: %s" , cache_key )
290315 # if embedding not successful put the item to queue to be handled at
@@ -379,5 +404,11 @@ def _parse_response(response: Response) -> Optional[List[float]]:
def clear_cache(self):
    """Empty the embedder's cache by delegating to ``self._cache.clear_cache()``."""
    cache = self._cache
    cache.clear_cache()
381406
407+ # remove in 3.33
382408 def set_cancelled (self ):
409+ warnings .warn (
410+ "set_cancelled is deprecated and will be removed in version 3.33, "
411+ "the process can be canceled by raising Error in callback" ,
412+ FutureWarning ,
413+ )
383414 self ._cancelled = True
0 commit comments