22import requests
33import spacy
44import sys
5+ import asyncio
56from mustache import prepare_and_render_mustache
67from spacy .tokens import DocBin
78
@@ -110,7 +111,8 @@ def parse_data_to_record_dict(record_chunk):
110111 # the script `attribute_calculators` does not exist. It will be inserted at runtime
111112 import attribute_calculators
112113
113- DEFAULT_USER_PROMPT_A2VYBG = attribute_calculators .USER_PROMPT_A2VYBG
114+ if data_type == "LLM_RESPONSE" :
115+ DEFAULT_USER_PROMPT_A2VYBG = attribute_calculators .USER_PROMPT_A2VYBG
114116
115117 vocab = spacy .blank (iso2_code ).vocab
116118
@@ -127,23 +129,53 @@ def parse_data_to_record_dict(record_chunk):
127129 progress_size = 100
128130 amount = len (record_dict_list )
129131 __print_progress (0.0 )
130- for record_dict in record_dict_list :
131- attribute_calculators .USER_PROMPT_A2VYBG = prepare_and_render_mustache (
132- DEFAULT_USER_PROMPT_A2VYBG , record_dict
133- )
134-
135- idx += 1
136- if idx % progress_size == 0 :
137- progress = round (idx / amount , 2 )
138- __print_progress (progress )
139- attr_value = attribute_calculators .ac (record_dict ["data" ])
140- if not check_data_type (attr_value ):
141- raise ValueError (
142- f"Attribute value `{ attr_value } ` is of type { type (attr_value )} , "
143- f"but data_type { data_type } requires "
144- f"{ str (py_data_types ) if len (py_data_types ) > 1 else str (py_data_types [0 ])} ."
132+
async def process_llm_record_batch(record_dict_batch: list):
    """Compute the attribute value for each record of one batch, in order.

    For every record the default user prompt template is rendered with the
    record and stored on ``attribute_calculators.USER_PROMPT_A2VYBG``; then
    ``attribute_calculators.ac`` is awaited with the record's ``"data"``.
    Accepted values are written to the shared
    ``calculated_attribute_by_record_id`` mapping under the record's ``"id"``.

    Raises:
        ValueError: if a computed value does not pass ``check_data_type``.
    """
    for current_record in record_dict_batch:
        # NOTE(review): the rendered prompt is stored in a module attribute that
        # all concurrently running batches write to. Presumably `ac` reads it
        # before its first suspension point -- confirm.
        rendered_prompt = prepare_and_render_mustache(
            DEFAULT_USER_PROMPT_A2VYBG, current_record
        )
        attribute_calculators.USER_PROMPT_A2VYBG = rendered_prompt

        attr_value: str = await attribute_calculators.ac(current_record["data"])

        if check_data_type(attr_value):
            calculated_attribute_by_record_id[current_record["id"]] = attr_value
            continue

        # The value has the wrong type for the configured data_type: abort.
        raise ValueError(
            f"Attribute value `{ attr_value } ` is of type { type (attr_value )} , "
            f"but data_type { data_type } requires "
            f"{ str (py_data_types ) if len (py_data_types ) > 1 else str (py_data_types [0 ])} ."
        )
150+
async def process_async_llm_calls(record_dict_list):
    """Split the records into batches and process all batches concurrently.

    The records are divided into at most ``attribute_calculators.NUM_WORKERS``
    contiguous batches; one ``process_llm_record_batch`` coroutine is created
    per batch and all of them are awaited together with ``asyncio.gather``.

    Args:
        record_dict_list: list of record dicts to process. An empty list is a
            no-op.

    Raises:
        Whatever ``process_llm_record_batch`` raises; ``asyncio.gather``
        propagates the first exception to the caller.
    """
    if not record_dict_list:
        # Nothing to do; also keeps the slicing step below from being zero.
        return

    # Guard against a zero or negative worker count, which would otherwise
    # raise ZeroDivisionError or silently yield no batches.
    num_workers = max(1, int(attribute_calculators.NUM_WORKERS))

    # Ceiling division: every batch holds at least one record and there are
    # never more batches than workers. Floor division gave a step of 0 (and a
    # ValueError from range) whenever there were fewer records than workers.
    batch_size = -(-len(record_dict_list) // num_workers)

    record_dict_batches = [
        record_dict_list[start : start + batch_size]
        for start in range(0, len(record_dict_list), batch_size)
    ]
    tasks = [process_llm_record_batch(batch) for batch in record_dict_batches]
    await asyncio.gather(*tasks)
159+
# LLM responses are computed concurrently in batches; every other data type is
# computed record by record with incremental progress output.
if data_type != "LLM_RESPONSE":
    for record_dict in record_dict_list:
        idx += 1
        if idx % progress_size == 0:
            progress = round(idx / amount, 2)
            __print_progress(progress)

        attr_value = attribute_calculators.ac(record_dict["data"])
        if check_data_type(attr_value):
            calculated_attribute_by_record_id[record_dict["id"]] = attr_value
            continue

        # The value has the wrong type for the configured data_type: abort.
        raise ValueError(
            f"Attribute value `{ attr_value } ` is of type { type (attr_value )} , "
            f"but data_type { data_type } requires "
            f"{ str (py_data_types ) if len (py_data_types ) > 1 else str (py_data_types [0 ])} ."
        )
    __print_progress(1.0)
    print("Finished execution.")
    requests.put(payload_url, json=calculated_attribute_by_record_id)
else:
    # Run all batches to completion, then upload the collected results.
    asyncio.run(process_async_llm_calls(record_dict_list))
    requests.put(payload_url, json=calculated_attribute_by_record_id)
    __print_progress(1.0)
    print("Finished execution.")
0 commit comments