Skip to content

Commit 998c272

Browse files
add text List (#82)
1 parent cb1a832 commit 998c272

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

run_ac.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def get_check_data_type_function(data_type: str) -> Tuple[List[Type], Callable]:
2424
return [list], __check_data_type_embedding_list
2525
elif data_type == "LLM_RESPONSE":
2626
return [str], __check_data_type_text
27+
elif data_type == "TEXT_LIST":
28+
return [list], __check_data_type_text_list
2729
else:
2830
raise ValueError(f"Unknown data type: {data_type}")
2931

@@ -73,6 +75,15 @@ def __check_data_type_embedding_list(attr_value: Any) -> bool:
7375
return True
7476

7577

78+
def __check_data_type_text_list(attr_value: Any) -> bool:
79+
if not isinstance(attr_value, list):
80+
return False
81+
for e in attr_value:
82+
if not isinstance(e, str):
83+
return False
84+
return True
85+
86+
7687
def __print_progress_a2vybg(progress: float) -> None:
7788
print(f"progress: {progress}", flush=True)
7889

@@ -98,7 +109,7 @@ def load_data_dict_a2vybg(record: Dict[str, Any]) -> Dict[str, Any]:
98109

99110

100111
def parse_data_to_record_dict_a2vybg(
101-
record_chunk: List[Dict[str, Any]]
112+
record_chunk: List[Dict[str, Any]],
102113
) -> List[Dict[str, Any]]:
103114
result = []
104115
for r in record_chunk:
@@ -134,7 +145,7 @@ def save_ac_value_a2vybg(record_id: str, attr_value: Any) -> None:
134145

135146

136147
def process_attribute_calculation_a2vybg(
137-
record_dict_list: List[Dict[str, Any]]
148+
record_dict_list: List[Dict[str, Any]],
138149
) -> None:
139150
for record_dict in record_dict_list:
140151
attr_value: Any = attribute_calculators.ac(record_dict["data"])
@@ -147,7 +158,7 @@ def check_abort_status_a2vybg() -> bool:
147158

148159

149160
async def process_llm_record_batch_a2vybg(
150-
record_dict_batch: List[Dict[str, Any]]
161+
record_dict_batch: List[Dict[str, Any]],
151162
) -> None:
152163
global should_abort_a2vybg
153164

@@ -178,7 +189,7 @@ def make_batches(
178189

179190

180191
async def process_async_llm_calls_a2vybg(
181-
record_dict_list: List[Dict[str, Any]]
192+
record_dict_list: List[Dict[str, Any]],
182193
) -> None:
183194
batch_size = max(amount_a2vybg // int(attribute_calculators.NUM_WORKERS_A2VYBG), 1)
184195
tasks = [

0 commit comments

Comments
 (0)