1919from parea .experiment .dvc import save_results_to_dvc_if_init
2020from parea .helpers import duplicate_dicts , gen_random_name , is_logging_disabled
2121from parea .schemas import EvaluationResult
22- from parea .schemas .models import CreateExperimentRequest , ExperimentSchema , ExperimentStatsSchema , FinishExperimentRequestSchema
22+ from parea .schemas .models import CreateExperimentRequest , ExperimentSchema , ExperimentStatsSchema , ExperimentStatus , FinishExperimentRequestSchema
2323from parea .utils .trace_utils import thread_ids_running_evals , trace_data
2424from parea .utils .universal_encoder import json_dumps
2525
@@ -138,13 +138,26 @@ def limit_concurrency_sync(sample):
138138 return func (_parea_target_field = target , ** sample_copy )
139139
140140 if inspect .iscoroutinefunction (func ):
141- tasks = [limit_concurrency (sample ) for sample in data ]
141+ tasks = [asyncio . ensure_future ( limit_concurrency (sample ) ) for sample in data ]
142142 else :
143143 executor = ThreadPoolExecutor (max_workers = n_workers )
144144 loop = asyncio .get_event_loop ()
145- tasks = [loop .run_in_executor (executor , partial (limit_concurrency_sync , sample )) for sample in data ]
146- for _task in tqdm_asyncio .as_completed (tasks , total = len_test_cases ):
147- await _task
145+ tasks = [asyncio .ensure_future (loop .run_in_executor (executor , partial (limit_concurrency_sync , sample ))) for sample in data ]
146+
147+ status = ExperimentStatus .COMPLETED
148+ with tqdm (total = len (tasks ), desc = "Running samples" , unit = "sample" ) as pbar :
149+ try :
150+ for coro in asyncio .as_completed (tasks ):
151+ try :
152+ await coro
153+ pbar .update (1 )
154+ except Exception as e :
155+ print (f"\n Experiment stopped due to an error: { str (e )} \n " )
156+ status = ExperimentStatus .FAILED
157+ for task in tasks :
158+ task .cancel ()
159+ except asyncio .CancelledError :
160+ pass
148161
149162 await asyncio .sleep (0.2 )
150163 total_evals = len (thread_ids_running_evals .get ())
@@ -162,7 +175,7 @@ def limit_concurrency_sync(sample):
162175 else :
163176 dataset_level_eval_results = []
164177
165- experiment_stats : ExperimentStatsSchema = p .finish_experiment (experiment_uuid , FinishExperimentRequestSchema (dataset_level_stats = dataset_level_eval_results ))
178+ experiment_stats : ExperimentStatsSchema = p .finish_experiment (experiment_uuid , FinishExperimentRequestSchema (status = status , dataset_level_stats = dataset_level_eval_results ))
166179 stat_name_to_avg_std = calculate_avg_std_for_experiment (experiment_stats )
167180 if dataset_level_eval_results :
168181 stat_name_to_avg_std .update (** {eval_result .name : eval_result .score for eval_result in dataset_level_eval_results })
0 commit comments