1+ from typing import Any , Dict , List , Generator , Callable , Tuple , Type
2+ import asyncio
13import json
24import requests
35import spacy
46import sys
5- import asyncio
67from mustache import prepare_and_render_mustache
78from spacy .tokens import DocBin
89
910
10- def get_check_data_type_function (data_type ) :
11+ def get_check_data_type_function (data_type : str ) -> Tuple [ List [ Type ], Callable ] :
1112 if data_type == "INTEGER" :
1213 return [int ], __check_data_type_integer
1314 elif data_type == "FLOAT" :
@@ -26,13 +27,13 @@ def get_check_data_type_function(data_type):
2627 raise ValueError (f"Unknown data type: { data_type } " )
2728
2829
def __check_data_type_integer(attr_value: Any) -> bool:
    """Return True when ``attr_value`` is acceptable for the INTEGER data type.

    ``None`` is accepted. ``bool`` values are rejected even though ``bool`` is
    a subclass of ``int`` in Python: without the explicit exclusion,
    ``True``/``False`` would silently pass as integers.

    Args:
        attr_value: The value to validate.

    Returns:
        True if the value is ``None`` or a non-bool ``int``, otherwise False.
    """
    if attr_value is None:
        return True
    # isinstance(True, int) is True, so bool has to be ruled out explicitly.
    return isinstance(attr_value, int) and not isinstance(attr_value, bool)
3334
3435
35- def __check_data_type_float (attr_value ) :
36+ def __check_data_type_float (attr_value : Any ) -> bool :
3637 if (
3738 attr_value is not None
3839 and not isinstance (attr_value , float )
@@ -42,27 +43,27 @@ def __check_data_type_float(attr_value):
4243 return True
4344
4445
def __check_data_type_boolean(attr_value: Any) -> bool:
    """Return True only when ``attr_value`` is a ``bool``.

    Unlike the integer check, ``None`` is not accepted here.
    """
    return isinstance(attr_value, bool)
4950
5051
def __check_data_type_category(attr_value: Any) -> bool:
    """Validate a value for the CATEGORY data type.

    Args:
        attr_value: The value to validate.

    Returns:
        True for a non-empty string, False for any non-string value.

    Raises:
        ValueError: If ``attr_value`` is the empty string.
    """
    is_string = isinstance(attr_value, str)
    if is_string and attr_value == "":
        raise ValueError("Category cannot be empty string")
    return is_string
5758
5859
def __check_data_type_text(attr_value: Any) -> bool:
    """Return True only when ``attr_value`` is a ``str`` (empty strings included)."""
    return isinstance(attr_value, str)
6364
6465
65- def __check_data_type_embedding_list (attr_value ) :
66+ def __check_data_type_embedding_list (attr_value : Any ) -> bool :
6667 if not isinstance (attr_value , list ):
6768 return False
6869 for e in attr_value :
@@ -75,7 +76,9 @@ def __print_progress(progress: float) -> None:
7576 print (f"progress: { progress } " , flush = True )
7677
7778
78- def load_data_dict (record ):
79+ def load_data_dict (record : Dict [str , Any ]) -> Dict [str , Any ]:
80+ global vocab
81+
7982 if record ["bytes" ][:2 ] == "\\ x" :
8083 record ["bytes" ] = record ["bytes" ][2 :]
8184 else :
@@ -95,13 +98,68 @@ def load_data_dict(record):
9598 return data_dict
9699
97100
def parse_data_to_record_dict(
    record_chunk: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Convert raw records into ``{"id": ..., "data": ...}`` dictionaries.

    Args:
        record_chunk: Raw records; each must provide ``"record_id"`` and
            whatever ``load_data_dict`` reads from it.

    Returns:
        One dictionary per input record, in input order, holding the record id
        under ``"id"`` and the result of ``load_data_dict`` under ``"data"``.
    """
    return [
        {"id": raw_record["record_id"], "data": load_data_dict(raw_record)}
        for raw_record in record_chunk
    ]
103108
104109
def save_ac_value(record_id: str, attr_value: Any) -> None:
    """Validate one calculated attribute value, store it, and report progress.

    Reads the module-level ``check_data_type``, ``py_data_types``,
    ``data_type``, ``progress_size`` and ``amount``; writes into
    ``calculated_attribute_by_record_id`` and increments ``processed_records``.

    Args:
        record_id: Key under which the value is stored.
        attr_value: The calculated value to validate and store.

    Raises:
        ValueError: If ``check_data_type`` rejects ``attr_value``.
    """
    # Only the counter is rebound here; every other module-level name is
    # merely read or mutated in place, so it needs no ``global`` declaration.
    global processed_records

    value_is_valid = check_data_type(attr_value)
    if not value_is_valid:
        raise ValueError(
            f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
            f"but data_type {data_type} requires "
            f"{str(py_data_types) if len(py_data_types) > 1 else str(py_data_types[0])}."
        )

    calculated_attribute_by_record_id[record_id] = attr_value
    processed_records += 1

    # Emit a progress line once every ``progress_size`` stored values.
    if processed_records % progress_size == 0:
        fraction_done = round(processed_records / amount, 2)
        __print_progress(fraction_done)
125+
126+
def process_attribute_calculation(record_dict_list: List[Dict[str, Any]]) -> None:
    """Run the synchronous attribute calculator over every record.

    For each record, ``attribute_calculators.ac`` is called with the record's
    ``"data"`` entry and the result is handed to ``save_ac_value`` together
    with the record's ``"id"``.

    Args:
        record_dict_list: Records shaped as ``{"id": ..., "data": ...}``.
    """
    for entry in record_dict_list:
        calculated: Any = attribute_calculators.ac(entry["data"])
        save_ac_value(entry["id"], calculated)
131+
132+
async def process_llm_record_batch(record_dict_batch: List[Dict[str, Any]]) -> None:
    """Calculate the attribute for each record in a batch, one after another.

    Before every call, the default user prompt is rendered against the whole
    record dict and written to ``attribute_calculators.USER_PROMPT_A2VYBG``.
    The awaited result of ``attribute_calculators.ac`` is passed to
    ``save_ac_value``.

    Args:
        record_dict_batch: Records shaped as ``{"id": ..., "data": ...}``.
    """
    for entry in record_dict_batch:
        rendered_prompt = prepare_and_render_mustache(
            DEFAULT_USER_PROMPT_A2VYBG, entry
        )
        # NOTE(review): this module attribute is shared by every batch task
        # running concurrently; presumably ``ac`` reads it before its first
        # suspension point -- confirm, otherwise prompts could be mixed up.
        attribute_calculators.USER_PROMPT_A2VYBG = rendered_prompt

        llm_answer: str = await attribute_calculators.ac(entry["data"])
        save_ac_value(entry["id"], llm_answer)
143+
144+
async def process_async_llm_calls(record_dict_list: List[Dict[str, Any]]) -> None:
    """Split the records into batches and process the batches concurrently.

    The records are cut into at most ``attribute_calculators.NUM_WORKERS_A2VYBG``
    contiguous batches; each batch is handled by ``process_llm_record_batch``
    and all batches are awaited together.

    Args:
        record_dict_list: Records shaped as ``{"id": ..., "data": ...}``.
    """
    # Size the batches from the argument itself rather than a module global,
    # so the function works on whatever list it is actually given.
    total = len(record_dict_list)

    # Clamp to one worker so a configured value of 0 cannot divide by zero.
    worker_count = max(int(attribute_calculators.NUM_WORKERS_A2VYBG), 1)

    # Ceiling division: floor division would yield MORE batches than workers
    # whenever ``total`` is not a multiple of ``worker_count``
    # (e.g. 10 records / 4 workers -> size 2 -> 5 concurrent tasks).
    batch_size = max(-(-total // worker_count), 1)

    tasks = [
        # Slicing clamps at the end of the list, so no explicit bound is needed.
        process_llm_record_batch(record_dict_list[start : start + batch_size])
        for start in range(0, total, batch_size)
    ]
    await asyncio.gather(*tasks)
161+
162+
105163if __name__ == "__main__" :
106164 _ , iso2_code , payload_url , data_type = sys .argv
107165
@@ -111,8 +169,9 @@ def parse_data_to_record_dict(record_chunk):
111169 # the script `labeling_functions` does not exist. It will be inserted at runtime
112170 import attribute_calculators
113171
114- if data_type == "LLM_RESPONSE" :
115- DEFAULT_USER_PROMPT_A2VYBG = attribute_calculators .USER_PROMPT_A2VYBG
172+ DEFAULT_USER_PROMPT_A2VYBG = getattr (
173+ attribute_calculators , "USER_PROMPT_A2VYBG" , None
174+ )
116175
117176 vocab = spacy .blank (iso2_code ).vocab
118177
@@ -125,64 +184,19 @@ def parse_data_to_record_dict(record_chunk):
125184
126185 print ("Running attribute calculation." )
127186 calculated_attribute_by_record_id = {}
128- idx = 0
129187 amount = len (record_dict_list )
130- progress_size = min (100 , amount // 10 )
188+ progress_size = min (
189+ 100 ,
190+ max (amount // int (getattr (attribute_calculators , "NUM_WORKERS_A2VYBG" , 1 )), 1 ),
191+ )
131192 processed_records = 0
132- __print_progress (processed_records / amount )
133-
134- async def process_llm_record_batch (record_dict_batch : list ):
135- """Process a batch of record_dicts, writes results into shared var calculated_attribute_by_record_id."""
136-
137- for record_dict in record_dict_batch :
138- attribute_calculators .USER_PROMPT_A2VYBG = prepare_and_render_mustache (
139- DEFAULT_USER_PROMPT_A2VYBG , record_dict
140- )
141-
142- attr_value : str = await attribute_calculators .ac (record_dict ["data" ])
143-
144- if not check_data_type (attr_value ):
145- raise ValueError (
146- f"Attribute value `{ attr_value } ` is of type { type (attr_value )} , "
147- f"but data_type { data_type } requires "
148- f"{ str (py_data_types ) if len (py_data_types ) > 1 else str (py_data_types [0 ])} ."
149- )
150- calculated_attribute_by_record_id [record_dict ["id" ]] = attr_value
151- global processed_records
152- processed_records = processed_records + 1
153- if processed_records % progress_size == 0 :
154- __print_progress (round (processed_records / amount , 2 ))
155-
156- async def process_async_llm_calls (record_dict_list ):
157- batch_size = max (
158- len (record_dict_list ) // int (attribute_calculators .NUM_WORKERS_A2VYBG ), 1
159- )
160- record_dict_batches = [
161- record_dict_list [i : i + batch_size ]
162- for i in range (0 , len (record_dict_list ), batch_size )
163- ]
164- tasks = [process_llm_record_batch (batch ) for batch in record_dict_batches ]
165- await asyncio .gather (* tasks )
193+ __print_progress (0.0 )
166194
167195 if data_type == "LLM_RESPONSE" :
168196 asyncio .run (process_async_llm_calls (record_dict_list ))
169- requests .put (payload_url , json = calculated_attribute_by_record_id )
170- __print_progress (1.0 )
171- print ("Finished execution." )
172197 else :
173- for record_dict in record_dict_list :
174- idx += 1
175- if idx % progress_size == 0 :
176- progress = round (idx / amount , 2 )
177- __print_progress (progress )
178- attr_value = attribute_calculators .ac (record_dict ["data" ])
179- if not check_data_type (attr_value ):
180- raise ValueError (
181- f"Attribute value `{ attr_value } ` is of type { type (attr_value )} , "
182- f"but data_type { data_type } requires "
183- f"{ str (py_data_types ) if len (py_data_types ) > 1 else str (py_data_types [0 ])} ."
184- )
185- calculated_attribute_by_record_id [record_dict ["id" ]] = attr_value
186- __print_progress (1.0 )
187- print ("Finished execution." )
188- requests .put (payload_url , json = calculated_attribute_by_record_id )
198+ process_attribute_calculation (record_dict_list )
199+
200+ __print_progress (1.0 )
201+ print ("Finished execution." )
202+ requests .put (payload_url , json = calculated_attribute_by_record_id )
0 commit comments