Skip to content

Commit e12aea1

Browse files
committed
perf: unprototype run_ac.py
1 parent 46699bd commit e12aea1

File tree

1 file changed

+81
-67
lines changed

1 file changed

+81
-67
lines changed

run_ac.py

Lines changed: 81 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1+
from typing import Any, Dict, List, Generator, Callable, Tuple, Type
2+
import asyncio
13
import json
24
import requests
35
import spacy
46
import sys
5-
import asyncio
67
from mustache import prepare_and_render_mustache
78
from spacy.tokens import DocBin
89

910

10-
def get_check_data_type_function(data_type):
11+
def get_check_data_type_function(data_type: str) -> Tuple[List[Type], Callable]:
1112
if data_type == "INTEGER":
1213
return [int], __check_data_type_integer
1314
elif data_type == "FLOAT":
@@ -26,13 +27,13 @@ def get_check_data_type_function(data_type):
2627
raise ValueError(f"Unknown data type: {data_type}")
2728

2829

29-
def __check_data_type_integer(attr_value):
30+
def __check_data_type_integer(attr_value: Any) -> bool:
    """Return whether ``attr_value`` is acceptable for an INTEGER attribute.

    ``None`` is accepted; every other value must be an ``int`` instance.
    """
    # NOTE(review): ``bool`` is a subclass of ``int``, so True/False pass this
    # check as well -- confirm that this is intended.
    return attr_value is None or isinstance(attr_value, int)
3334

3435

35-
def __check_data_type_float(attr_value):
36+
def __check_data_type_float(attr_value: Any) -> bool:
3637
if (
3738
attr_value is not None
3839
and not isinstance(attr_value, float)
@@ -42,27 +43,27 @@ def __check_data_type_float(attr_value):
4243
return True
4344

4445

45-
def __check_data_type_boolean(attr_value):
46+
def __check_data_type_boolean(attr_value: Any) -> bool:
    """Return whether ``attr_value`` is a ``bool`` (``None`` is rejected)."""
    return isinstance(attr_value, bool)
4950

5051

51-
def __check_data_type_category(attr_value):
52+
def __check_data_type_category(attr_value: Any) -> bool:
    """Return whether ``attr_value`` is acceptable for a CATEGORY attribute.

    Returns:
        ``False`` for anything that is not a ``str``, ``True`` for a
        non-empty string.

    Raises:
        ValueError: if ``attr_value`` is the empty string.
    """
    if isinstance(attr_value, str):
        # An empty category is reported as an error rather than as a plain
        # type mismatch.
        if attr_value == "":
            raise ValueError("Category cannot be empty string")
        return True
    return False
5758

5859

59-
def __check_data_type_text(attr_value):
60+
def __check_data_type_text(attr_value: Any) -> bool:
    """Return whether ``attr_value`` is a ``str`` (the empty string is allowed)."""
    return isinstance(attr_value, str)
6364

6465

65-
def __check_data_type_embedding_list(attr_value):
66+
def __check_data_type_embedding_list(attr_value: Any) -> bool:
6667
if not isinstance(attr_value, list):
6768
return False
6869
for e in attr_value:
@@ -75,7 +76,9 @@ def __print_progress(progress: float) -> None:
7576
print(f"progress: {progress}", flush=True)
7677

7778

78-
def load_data_dict(record):
79+
def load_data_dict(record: Dict[str, Any]) -> Dict[str, Any]:
80+
global vocab
81+
7982
if record["bytes"][:2] == "\\x":
8083
record["bytes"] = record["bytes"][2:]
8184
else:
@@ -95,13 +98,68 @@ def load_data_dict(record):
9598
return data_dict
9699

97100

98-
def parse_data_to_record_dict(record_chunk):
101+
def parse_data_to_record_dict(
    record_chunk: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Convert raw records into ``{"id": ..., "data": ...}`` dictionaries.

    For each record, ``"id"`` is its ``"record_id"`` value and ``"data"`` is
    whatever ``load_data_dict`` returns for that record. Input order is kept.
    """
    return [
        {"id": record["record_id"], "data": load_data_dict(record)}
        for record in record_chunk
    ]
103108

104109

110+
def save_ac_value(record_id: str, attr_value: Any) -> None:
    """Validate one calculated attribute value, store it and report progress.

    Args:
        record_id: key under which the value is stored in the module-level
            ``calculated_attribute_by_record_id`` mapping.
        attr_value: value produced by the attribute calculator.

    Raises:
        ValueError: if the module-level ``check_data_type`` callable rejects
            ``attr_value``.
    """
    # Only ``processed_records`` is rebound in this function. The other
    # module-level names used here (``check_data_type``, ``py_data_types``,
    # ``data_type``, ``progress_size``, ``amount`` and
    # ``calculated_attribute_by_record_id``) are read or mutated in place and
    # therefore need no ``global`` declaration.
    global processed_records

    if not check_data_type(attr_value):
        # Show the bare type when there is exactly one, else the whole list.
        expected = py_data_types if len(py_data_types) > 1 else py_data_types[0]
        raise ValueError(
            f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
            f"but data_type {data_type} requires "
            f"{str(expected)}."
        )

    calculated_attribute_by_record_id[record_id] = attr_value

    processed_records += 1
    # Emit a progress line every ``progress_size`` processed records.
    if processed_records % progress_size == 0:
        __print_progress(round(processed_records / amount, 2))
125+
126+
127+
def process_attribute_calculation(record_dict_list: List[Dict[str, Any]]) -> None:
    """Run the synchronous attribute calculator over every record.

    Each record's ``"data"`` entry is passed to ``attribute_calculators.ac``
    and the result is handed to ``save_ac_value`` under the record's ``"id"``.
    """
    for entry in record_dict_list:
        calculated: Any = attribute_calculators.ac(entry["data"])
        save_ac_value(entry["id"], calculated)
131+
132+
133+
async def process_llm_record_batch(record_dict_batch: List[Dict[str, Any]]) -> None:
    """Calculate and store the LLM attribute value for each record of a batch.

    For every record the default prompt template is rendered with the record
    and written to ``attribute_calculators.USER_PROMPT_A2VYBG``; then the
    calculator coroutine is awaited and its result passed to ``save_ac_value``.
    """
    for entry in record_dict_batch:
        # NOTE(review): the rendered prompt is stored in a module attribute
        # that all concurrently running batches share -- presumably ``ac``
        # reads it before its first suspension point; confirm.
        rendered_prompt = prepare_and_render_mustache(
            DEFAULT_USER_PROMPT_A2VYBG, entry
        )
        attribute_calculators.USER_PROMPT_A2VYBG = rendered_prompt

        llm_answer: str = await attribute_calculators.ac(entry["data"])
        save_ac_value(entry["id"], llm_answer)
143+
144+
145+
async def process_async_llm_calls(record_dict_list: List[Dict[str, Any]]) -> None:
    """Split the records into batches and process all batches concurrently.

    The records are cut into at most ``attribute_calculators.NUM_WORKERS_A2VYBG``
    contiguous batches; each batch is handled by ``process_llm_record_batch``
    and all of them are awaited together via ``asyncio.gather``.

    Args:
        record_dict_list: records in the ``{"id": ..., "data": ...}`` form.
    """
    total = len(record_dict_list)
    # A worker setting below 1 is treated as a single worker instead of
    # raising ZeroDivisionError.
    workers = max(int(attribute_calculators.NUM_WORKERS_A2VYBG), 1)
    # Ceiling division: floor division would leave a remainder that spills
    # into extra batches, i.e. more concurrent tasks than configured workers.
    batch_size = max(-(-total // workers), 1)

    # Slicing clamps at the end of the list, so no explicit bound is needed.
    tasks = [
        process_llm_record_batch(record_dict_list[start : start + batch_size])
        for start in range(0, total, batch_size)
    ]
    await asyncio.gather(*tasks)
161+
162+
105163
if __name__ == "__main__":
106164
_, iso2_code, payload_url, data_type = sys.argv
107165

@@ -111,8 +169,9 @@ def parse_data_to_record_dict(record_chunk):
111169
# the script `attribute_calculators` does not exist. It will be inserted at runtime
112170
import attribute_calculators
113171

114-
if data_type == "LLM_RESPONSE":
115-
DEFAULT_USER_PROMPT_A2VYBG = attribute_calculators.USER_PROMPT_A2VYBG
172+
DEFAULT_USER_PROMPT_A2VYBG = getattr(
173+
attribute_calculators, "USER_PROMPT_A2VYBG", None
174+
)
116175

117176
vocab = spacy.blank(iso2_code).vocab
118177

@@ -125,64 +184,19 @@ def parse_data_to_record_dict(record_chunk):
125184

126185
print("Running attribute calculation.")
127186
calculated_attribute_by_record_id = {}
128-
idx = 0
129187
amount = len(record_dict_list)
130-
progress_size = min(100, amount // 10)
188+
progress_size = min(
189+
100,
190+
max(amount // int(getattr(attribute_calculators, "NUM_WORKERS_A2VYBG", 1)), 1),
191+
)
131192
processed_records = 0
132-
__print_progress(processed_records / amount)
133-
134-
async def process_llm_record_batch(record_dict_batch: list):
135-
"""Process a batch of record_dicts, writes results into shared var calculated_attribute_by_record_id."""
136-
137-
for record_dict in record_dict_batch:
138-
attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
139-
DEFAULT_USER_PROMPT_A2VYBG, record_dict
140-
)
141-
142-
attr_value: str = await attribute_calculators.ac(record_dict["data"])
143-
144-
if not check_data_type(attr_value):
145-
raise ValueError(
146-
f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
147-
f"but data_type {data_type} requires "
148-
f"{str(py_data_types) if len(py_data_types) > 1 else str(py_data_types[0])}."
149-
)
150-
calculated_attribute_by_record_id[record_dict["id"]] = attr_value
151-
global processed_records
152-
processed_records = processed_records + 1
153-
if processed_records % progress_size == 0:
154-
__print_progress(round(processed_records / amount, 2))
155-
156-
async def process_async_llm_calls(record_dict_list):
157-
batch_size = max(
158-
len(record_dict_list) // int(attribute_calculators.NUM_WORKERS_A2VYBG), 1
159-
)
160-
record_dict_batches = [
161-
record_dict_list[i : i + batch_size]
162-
for i in range(0, len(record_dict_list), batch_size)
163-
]
164-
tasks = [process_llm_record_batch(batch) for batch in record_dict_batches]
165-
await asyncio.gather(*tasks)
193+
__print_progress(0.0)
166194

167195
if data_type == "LLM_RESPONSE":
168196
asyncio.run(process_async_llm_calls(record_dict_list))
169-
requests.put(payload_url, json=calculated_attribute_by_record_id)
170-
__print_progress(1.0)
171-
print("Finished execution.")
172197
else:
173-
for record_dict in record_dict_list:
174-
idx += 1
175-
if idx % progress_size == 0:
176-
progress = round(idx / amount, 2)
177-
__print_progress(progress)
178-
attr_value = attribute_calculators.ac(record_dict["data"])
179-
if not check_data_type(attr_value):
180-
raise ValueError(
181-
f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
182-
f"but data_type {data_type} requires "
183-
f"{str(py_data_types) if len(py_data_types) > 1 else str(py_data_types[0])}."
184-
)
185-
calculated_attribute_by_record_id[record_dict["id"]] = attr_value
186-
__print_progress(1.0)
187-
print("Finished execution.")
188-
requests.put(payload_url, json=calculated_attribute_by_record_id)
198+
process_attribute_calculation(record_dict_list)
199+
200+
__print_progress(1.0)
201+
print("Finished execution.")
202+
requests.put(payload_url, json=calculated_attribute_by_record_id)

0 commit comments

Comments
 (0)