1919from parea .experiment .dvc import save_results_to_dvc_if_init
2020from parea .helpers import duplicate_dicts , gen_random_name , is_logging_disabled
2121from parea .schemas import EvaluationResult
22- from parea .schemas .models import CreateExperimentRequest , ExperimentSchema , ExperimentStatsSchema , FinishExperimentRequestSchema
22+ from parea .schemas .models import CreateExperimentRequest , ExperimentSchema , ExperimentStatsSchema , \
23+ FinishExperimentRequestSchema , ExperimentStatus
2324from parea .utils .trace_utils import thread_ids_running_evals , trace_data
2425from parea .utils .universal_encoder import json_dumps
2526
@@ -138,13 +139,23 @@ def limit_concurrency_sync(sample):
138139 return func (_parea_target_field = target , ** sample_copy )
139140
140141 if inspect .iscoroutinefunction (func ):
141- tasks = [limit_concurrency (sample ) for sample in data ]
142+ tasks = [asyncio . ensure_future ( limit_concurrency (sample ) ) for sample in data ]
142143 else :
143144 executor = ThreadPoolExecutor (max_workers = n_workers )
144145 loop = asyncio .get_event_loop ()
145- tasks = [loop .run_in_executor (executor , partial (limit_concurrency_sync , sample )) for sample in data ]
146- for _task in tqdm_asyncio .as_completed (tasks , total = len_test_cases ):
147- await _task
146+ tasks = [asyncio .ensure_future (loop .run_in_executor (executor , partial (limit_concurrency_sync , sample ))) for
147+ sample in data ]
148+
149+ done , pending = await asyncio .wait (tasks , return_when = asyncio .FIRST_EXCEPTION )
150+ status = ExperimentStatus .COMPLETED
151+ for task in done :
152+ try :
153+ await task
154+ except Exception as e :
155+ print (f"Experiment stopped due to an error: { str (e )} " )
156+ for _p in pending :
157+ _p .cancel ()
158+ status = ExperimentStatus .FAILED
148159
149160 await asyncio .sleep (0.2 )
150161 total_evals = len (thread_ids_running_evals .get ())
@@ -162,7 +173,7 @@ def limit_concurrency_sync(sample):
162173 else :
163174 dataset_level_eval_results = []
164175
165- experiment_stats : ExperimentStatsSchema = p .finish_experiment (experiment_uuid , FinishExperimentRequestSchema (dataset_level_stats = dataset_level_eval_results ))
176+ experiment_stats : ExperimentStatsSchema = p .finish_experiment (experiment_uuid , FinishExperimentRequestSchema (status = status , dataset_level_stats = dataset_level_eval_results ))
166177 stat_name_to_avg_std = calculate_avg_std_for_experiment (experiment_stats )
167178 if dataset_level_eval_results :
168179 stat_name_to_avg_std .update (** {eval_result .name : eval_result .score for eval_result in dataset_level_eval_results })
0 commit comments