@@ -44,6 +44,7 @@ def generate_batches(
4444 attribute_values_raw : List [str ],
4545 embedder : Transformer ,
4646 attribute_name : str ,
47+ for_delta : bool = False ,
4748) -> Iterator [Dict [List [str ], List [Any ]]]:
4849 length = len (record_ids )
4950 record_batches = []
@@ -61,7 +62,10 @@ def generate_batches(
6162 record_batches .append (record_ids_batch )
6263 document_batches .extend (documents )
6364
64- embedding_batches = embedder .fit_transform (document_batches , as_generator = True )
65+ if for_delta :
66+ embedding_batches = embedder .transform (document_batches , as_generator = True )
67+ else :
68+ embedding_batches = embedder .fit_transform (document_batches , as_generator = True )
6569 for record_batch in record_batches :
6670 yield {"record_ids" : record_batch , "embeddings" : next (embedding_batches )}
6771
@@ -193,6 +197,13 @@ def run_encoding(
193197 initial_count = record .count_attribute_list_entries (project_id , attribute_name )
194198 else :
195199 initial_count = record .count (project_id )
200+
is_delta = False
# refinery gateway handles delta logic beforehand so if count is 0 we can be sure it's not a delta
# NOTE: the assignment expression must be parenthesized. `!=` binds tighter than `:=`,
# so without the parentheses tensor_count would be bound to the boolean result of the
# comparison (True/False) instead of the actual number of tensors.
if (tensor_count := embedding.get_tensor_count(embedding_id)) != 0:
    is_delta = True
    # subtract the tensors that already exist from the entry count determined above
    initial_count -= tensor_count
206+
196207 seed_str = embedding_name
197208 torch .manual_seed (zlib .adler32 (bytes (seed_str , "utf-8" )))
198209 notification .create (
@@ -214,15 +225,18 @@ def run_encoding(
214225 else :
215226 config_string = model
216227
217- embedder = get_embedder (
218- project_id ,
219- embedding_type ,
220- iso2_code ,
221- platform ,
222- model ,
223- api_token ,
224- additional_data ,
225- )
228+ if is_delta :
229+ embedder = __setup_tmp_embedder (project_id , embedding_id )
230+ else :
231+ embedder = get_embedder (
232+ project_id ,
233+ embedding_type ,
234+ iso2_code ,
235+ platform ,
236+ model ,
237+ api_token ,
238+ additional_data ,
239+ )
226240
227241 if not embedder :
228242 raise Exception (
@@ -253,7 +267,7 @@ def run_encoding(
253267
254268 try :
255269 record_ids , attribute_values_raw = record .get_attribute_data (
256- project_id , attribute_name
270+ project_id , attribute_name , is_delta , embedding_id
257271 )
258272 embedding .update_embedding_state_encoding (
259273 project_id ,
@@ -279,7 +293,8 @@ def run_encoding(
279293 True ,
280294 )
281295 send_project_update (project_id , f"notification_created:{ user_id } " , True )
282- embedding .delete_tensors (embedding_id , with_commit = True )
296+ if not is_delta :
297+ embedding .delete_tensors (embedding_id , with_commit = True )
283298 chunk = 0
284299 embedding_canceled = False
285300 for pair in generate_batches (
@@ -289,6 +304,7 @@ def run_encoding(
289304 attribute_values_raw ,
290305 embedder ,
291306 attribute_name ,
307+ for_delta = is_delta ,
292308 ):
293309 if chunk % 10 == 0 :
294310 session_token = general .remove_and_refresh_session (session_token , True )
0 commit comments