11import logging
22import traceback
3+ import time
4+ from typing import Any , List
35
46from controller import organization
7+ from controller .embedding import util as embedding_util
8+ from controller .embedding import connector as embedding_connector
59from starlette .endpoints import HTTPEndpoint
610from starlette .responses import PlainTextResponse , JSONResponse
711
812from controller .transfer .labelstudio import import_preperator
13+ from submodules .model .business_objects .tokenization import is_doc_bin_creation_running
914from submodules .s3 import controller as s3
10- from submodules .model .business_objects import organization
15+ from submodules .model .business_objects import (
16+ attribute ,
17+ embedding ,
18+ general ,
19+ organization ,
20+ tokenization ,
21+ )
1122
1223from controller .transfer import manager as transfer_manager
1324from controller .upload_task import manager as upload_task_manager
1627from controller .transfer import association_transfer_manager
1728from controller .auth import manager as auth
1829from controller .project import manager as project_manager
30+ from controller .attribute import manager as attribute_manager
1931
2032from submodules .model import enums , exceptions
2133from util .notification import create_notification
22- from submodules .model .enums import NotificationType
23- from submodules .model .models import UploadTask
24- from submodules .model .business_objects import general
25- from util import notification
34+ from submodules .model .enums import AttributeState , NotificationType , UploadStates
35+ from submodules .model .models import Embedding , UploadTask
36+ from util import daemon , notification
2637from controller .tokenization import tokenization_service
2738
2839logging .basicConfig (level = logging .DEBUG )
@@ -221,6 +232,7 @@ def init_file_import(task: UploadTask, project_id: str, is_global_update: bool)
221232 import_preperator .prepare_label_studio_import (project_id , task )
222233 else :
223234 transfer_manager .import_records_from_file (project_id , task )
235+ calculate_missing_attributes (project_id , task .user_id )
224236 elif "project" in task .file_type :
225237 transfer_manager .import_project (project_id , task )
226238 elif "knowledge_base" in task .file_type :
@@ -234,7 +246,10 @@ def init_file_import(task: UploadTask, project_id: str, is_global_update: bool)
234246 is_global_update ,
235247 )
236248 if task .file_type != "knowledge_base" :
237- tokenization_service .request_tokenize_project (project_id , str (task .user_id ))
249+ only_usable_attributes = task .file_type == "records_add"
250+ tokenization_service .request_tokenize_project (
251+ project_id , str (task .user_id ), True , only_usable_attributes
252+ )
238253
239254
240255def file_import_error_handling (
@@ -258,3 +273,160 @@ def file_import_error_handling(
258273 notification .send_organization_update (
259274 project_id , f"file_upload:{ str (task .id )} :state:{ task .state } " , is_global_update
260275 )
276+
277+
def calculate_missing_attributes(project_id: str, user_id: str) -> None:
    """Hand the attribute recalculation for a project over to ``daemon.run``.

    The actual work is done by ``__calculate_missing_attributes``; this
    function only passes it, together with its arguments, to ``daemon.run``.

    Args:
        project_id: Id of the project whose attributes are recalculated.
        user_id: Id of the user on whose behalf the recalculation runs.
    """
    daemon.run(__calculate_missing_attributes, project_id, user_id)
284+
285+
def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
    """Recalculate every usable attribute of a project, one after another.

    Steps:
      1. Reset all attributes currently in state USABLE to INITIAL.
      2. Wait until the project-wide doc bin creation is no longer running.
      3. For each attribute in order: start its calculation, wait while it is
         RUNNING, then wait until its doc bin creation is done.
      4. Trigger ``calculate_missing_embedding_tensors``.

    Progress is announced through ``notification.send_organization_update``.
    The context token obtained at the start is always passed to
    ``general.remove_and_refresh_session(ctx_token, False)`` before leaving,
    including on the early return and when an exception escapes.

    Args:
        project_id: Id of the project whose attributes are recalculated.
        user_id: Id of the user the calculations are started for.
    """
    # wait a few seconds to ensure that the process is started in the
    # tokenization service
    time.sleep(5)
    ctx_token = general.get_ctx_token()
    try:
        attributes_usable = attribute.get_all_ordered(
            project_id,
            True,
            state_filter=[
                enums.AttributeState.USABLE.value,
            ],
        )
        if not attributes_usable:
            # nothing to recalculate; the finally clause releases the session
            return

        # stored as plain id strings so that later session refreshes do not
        # affect the list we iterate over
        attribute_ids = [str(att_usable.id) for att_usable in attributes_usable]
        for att_id in attribute_ids:
            attribute.update(
                project_id, att_id, state=enums.AttributeState.INITIAL.value
            )
        general.commit()
        notification.send_organization_update(
            project_id=project_id, message="calculate_attribute:started:all"
        )

        # first check project tokenization completed
        i = 0
        while True:
            i += 1
            if i >= 60:
                # request a fresh session every 60 polls
                i = 0
                ctx_token = general.remove_and_refresh_session(ctx_token, True)
            if not tokenization.is_doc_bin_creation_running(project_id):
                break
            time.sleep(5)

        # next, ensure that the attributes are calculated and tokenized
        i = 0
        while True:
            time.sleep(1)
            i += 1
            if not attribute_ids:
                notification.send_organization_update(
                    project_id=project_id,
                    message="calculate_attribute:finished:all",
                )
                break
            if i >= 60:
                i = 0
                ctx_token = general.remove_and_refresh_session(ctx_token, True)

            # always work on the head of the list; it is only removed once
            # the attribute is neither INITIAL nor RUNNING and its doc bin
            # creation has finished
            current_att_id = attribute_ids[0]
            current_att = attribute.get(project_id, current_att_id)
            if current_att.state == enums.AttributeState.RUNNING.value:
                continue
            if current_att.state == enums.AttributeState.INITIAL.value:
                attribute_manager.calculate_user_attribute_all_records(
                    project_id, user_id, current_att_id, True
                )
            elif tokenization.is_doc_bin_creation_running_for_attribute(
                project_id, current_att.name
            ):
                time.sleep(5)
                continue
            else:
                attribute_ids.pop(0)
                notification.send_organization_update(
                    project_id=project_id,
                    message=f"calculate_attribute:finished:{current_att_id}",
                )
            time.sleep(5)
    finally:
        # release the session on every exit path (success, early return,
        # exception) so the context token is not leaked
        general.remove_and_refresh_session(ctx_token, False)

    calculate_missing_embedding_tensors(project_id, user_id)
358+
359+
def calculate_missing_embedding_tensors(project_id: str, user_id: str) -> None:
    """Hand the embedding recreation for a project over to ``daemon.run``.

    The actual work is done by ``__calculate_missing_embedding_tensors``;
    this function only passes it, together with its arguments, to
    ``daemon.run``.

    Args:
        project_id: Id of the project whose embeddings are recreated.
        user_id: Id of the user on whose behalf the embeddings are requested.
    """
    daemon.run(__calculate_missing_embedding_tensors, project_id, user_id)
366+
367+
def __calculate_missing_embedding_tensors(project_id: str, user_id: str) -> None:
    """Recreate the finished embeddings of a project.

    All embeddings returned by
    ``embedding.get_finished_embeddings_by_started_at`` are set to the waiting
    state and then recreated through ``__create_embeddings``. If recreation
    raises, every embedding still reported as waiting is set to failed. Once
    work has started, "embedding:finished:all" is always sent.

    The context token obtained at the start is always passed to
    ``general.remove_and_refresh_session(ctx_token, False)`` before leaving,
    including on the early return and when an exception escapes.

    Args:
        project_id: Id of the project whose embeddings are recreated.
        user_id: Id of the user the embedding requests are issued for.
    """
    ctx_token = general.get_ctx_token()
    try:
        embeddings = embedding.get_finished_embeddings_by_started_at(project_id)
        if not embeddings:
            # nothing to recreate; the outer finally releases the session
            return

        embedding_ids = [str(embed.id) for embed in embeddings]
        for embed_id in embedding_ids:
            embedding.update_embedding_state_waiting(project_id, embed_id)
        general.commit()

        try:
            ctx_token = __create_embeddings(
                project_id, embedding_ids, user_id, ctx_token
            )
        except Exception as e:
            print(
                f"Error while recreating embeddings for {project_id} when new records are uploaded : {e}"
            )
            # whatever is still waiting was not recreated; mark it as failed
            for embed in embedding.get_waiting_embeddings(project_id):
                embedding.update_embedding_state_failed(project_id, str(embed.id))
            general.commit()
        finally:
            notification.send_organization_update(
                project_id=project_id, message="embedding:finished:all"
            )
    finally:
        # release the session on every exit path so the token is not leaked
        general.remove_and_refresh_session(ctx_token, False)
394+
395+
def __create_embeddings(
    project_id: str,
    embedding_ids: List[str],
    user_id: str,
    ctx_token: Any,
) -> Any:
    """Delete and re-request each of the given embeddings, one at a time.

    Sends "embedding:started:all" first. For every id the session is
    refreshed, the embedding is looked up (ids that no longer resolve are
    skipped), its deletion is requested, and a new attribute-level or
    token-level embedding is requested with the config string taken from the
    old embedding's name.

    Args:
        project_id: Id of the project the embeddings belong to.
        embedding_ids: Ids of the embeddings to recreate.
        user_id: Id of the user the creation requests are issued for.
        ctx_token: Current session context token.

    Returns:
        The context token returned by the last session refresh (or the one
        passed in when ``embedding_ids`` is empty).
    """
    notification.send_organization_update(
        project_id=project_id, message="embedding:started:all"
    )
    waiting_state = enums.EmbeddingState.WAITING.value
    for current_id in embedding_ids:
        ctx_token = general.remove_and_refresh_session(ctx_token, request_new=True)
        embedding_item = embedding.get(project_id, current_id)
        if not embedding_item:
            continue

        embedding_connector.request_deleting_embedding(project_id, current_id)

        attribute_id = str(embedding_item.attribute_id)
        attribute_name = attribute.get(project_id, attribute_id).name
        # the embedding name is "<attribute>-<kind>-<config string>"; strip
        # the leading part to recover the config string
        if embedding_item.type == enums.EmbeddingType.ON_ATTRIBUTE.value:
            prefix = f"{attribute_name}-classification-"
            request_creation = (
                embedding_connector.request_creating_attribute_level_embedding
            )
        else:
            prefix = f"{attribute_name}-extraction-"
            request_creation = (
                embedding_connector.request_creating_token_level_embedding
            )
        config_string = embedding_item.name[len(prefix):]
        request_creation(project_id, attribute_id, user_id, config_string)

        time.sleep(5)
        # NOTE(review): embedding_item was fetched before the delete request;
        # whether its state is still meaningful here depends on the ORM
        # session - confirm that this wait behaves as intended.
        while (
            embedding_util.has_encoder_running(project_id)
            and embedding_item.state != waiting_state
        ):
            time.sleep(1)
    return ctx_token
0 commit comments