Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions run_ac.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def get_check_data_type_function(data_type: str) -> Tuple[List[Type], Callable]:
return [list], __check_data_type_embedding_list
elif data_type == "LLM_RESPONSE":
return [str], __check_data_type_text
elif data_type == "TEXT_LIST":
return [list], __check_data_type_text_list
else:
raise ValueError(f"Unknown data type: {data_type}")

Expand Down Expand Up @@ -73,6 +75,15 @@ def __check_data_type_embedding_list(attr_value: Any) -> bool:
return True


def __check_data_type_text_list(attr_value: Any) -> bool:
if not isinstance(attr_value, list):
return False
for e in attr_value:
if not isinstance(e, str):
return False
return True


def __print_progress_a2vybg(progress: float) -> None:
print(f"progress: {progress}", flush=True)

Expand All @@ -98,7 +109,7 @@ def load_data_dict_a2vybg(record: Dict[str, Any]) -> Dict[str, Any]:


def parse_data_to_record_dict_a2vybg(
record_chunk: List[Dict[str, Any]]
record_chunk: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
result = []
for r in record_chunk:
Expand Down Expand Up @@ -134,7 +145,7 @@ def save_ac_value_a2vybg(record_id: str, attr_value: Any) -> None:


def process_attribute_calculation_a2vybg(
record_dict_list: List[Dict[str, Any]]
record_dict_list: List[Dict[str, Any]],
) -> None:
for record_dict in record_dict_list:
attr_value: Any = attribute_calculators.ac(record_dict["data"])
Expand All @@ -147,7 +158,7 @@ def check_abort_status_a2vybg() -> bool:


async def process_llm_record_batch_a2vybg(
record_dict_batch: List[Dict[str, Any]]
record_dict_batch: List[Dict[str, Any]],
) -> None:
global should_abort_a2vybg

Expand Down Expand Up @@ -178,7 +189,7 @@ def make_batches(


async def process_async_llm_calls_a2vybg(
record_dict_list: List[Dict[str, Any]]
record_dict_list: List[Dict[str, Any]],
) -> None:
batch_size = max(amount_a2vybg // int(attribute_calculators.NUM_WORKERS_A2VYBG), 1)
tasks = [
Expand Down