Commit 21b157e

feat: stop experiments early
1 parent 0769b4b commit 21b157e

File tree

parea/experiment/experiment.py
parea/schemas/models.py
parea/utils/trace_integrations/instructor.py

3 files changed: +29 -14 lines changed

parea/experiment/experiment.py

Lines changed: 17 additions & 6 deletions
```diff
@@ -19,7 +19,8 @@
 from parea.experiment.dvc import save_results_to_dvc_if_init
 from parea.helpers import duplicate_dicts, gen_random_name, is_logging_disabled
 from parea.schemas import EvaluationResult
-from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, FinishExperimentRequestSchema
+from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, \
+    FinishExperimentRequestSchema, ExperimentStatus
 from parea.utils.trace_utils import thread_ids_running_evals, trace_data
 from parea.utils.universal_encoder import json_dumps
@@ -138,13 +139,23 @@ def limit_concurrency_sync(sample):
             return func(_parea_target_field=target, **sample_copy)
 
         if inspect.iscoroutinefunction(func):
-            tasks = [limit_concurrency(sample) for sample in data]
+            tasks = [asyncio.ensure_future(limit_concurrency(sample)) for sample in data]
         else:
             executor = ThreadPoolExecutor(max_workers=n_workers)
             loop = asyncio.get_event_loop()
-            tasks = [loop.run_in_executor(executor, partial(limit_concurrency_sync, sample)) for sample in data]
-        for _task in tqdm_asyncio.as_completed(tasks, total=len_test_cases):
-            await _task
+            tasks = [asyncio.ensure_future(loop.run_in_executor(executor, partial(limit_concurrency_sync, sample))) for sample in data]
+
+        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
+        status = ExperimentStatus.COMPLETED
+        for task in done:
+            try:
+                await task
+            except Exception as e:
+                print(f"Experiment stopped due to an error: {str(e)}")
+                for _p in pending:
+                    _p.cancel()
+                status = ExperimentStatus.FAILED
 
         await asyncio.sleep(0.2)
         total_evals = len(thread_ids_running_evals.get())
@@ -162,7 +173,7 @@ def limit_concurrency_sync(sample):
         else:
            dataset_level_eval_results = []
 
-        experiment_stats: ExperimentStatsSchema = p.finish_experiment(experiment_uuid, FinishExperimentRequestSchema(dataset_level_stats=dataset_level_eval_results))
+        experiment_stats: ExperimentStatsSchema = p.finish_experiment(experiment_uuid, FinishExperimentRequestSchema(status=status, dataset_level_stats=dataset_level_eval_results))
         stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats)
         if dataset_level_eval_results:
             stat_name_to_avg_std.update(**{eval_result.name: eval_result.score for eval_result in dataset_level_eval_results})
```
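The core of the early-stop behavior is the switch from draining every task via tqdm_asyncio.as_completed to asyncio.wait with return_when=asyncio.FIRST_EXCEPTION: the first raised exception hands back done/pending sets, pending tasks are cancelled, and the experiment finishes as failed. A minimal, self-contained sketch of that pattern (the job coroutine and its simulated failure are hypothetical; only the asyncio calls mirror the diff):

```python
import asyncio


async def job(i: int) -> int:
    # Hypothetical stand-in for one experiment sample; sample 2 fails.
    await asyncio.sleep(0.1 * i)
    if i == 2:
        raise RuntimeError(f"sample {i} failed")
    return i


async def main() -> None:
    # Wrapping in Tasks (ensure_future) lets asyncio.wait hand back
    # concrete done/pending sets instead of bare coroutines.
    tasks = [asyncio.ensure_future(job(i)) for i in range(5)]

    # Returns as soon as any task raises; otherwise waits for all of them.
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

    status = "completed"
    for task in done:
        try:
            await task  # re-raises the exception stored on a failed task
        except Exception as e:
            print(f"Experiment stopped due to an error: {e}")
            for p in pending:
                p.cancel()  # abandon samples that have not finished yet
            status = "failed"

    print(status, f"done={len(done)}", f"pending={len(pending)}")


asyncio.run(main())
```

One trade-off worth noting: replacing as_completed drops the per-task progress bar in exchange for fail-fast cancellation.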

parea/schemas/models.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -329,8 +329,16 @@ class CreateTestCaseCollection(CreateTestCases):
     column_names: List[str] = field(factory=list)
 
 
+class ExperimentStatus(str, Enum):
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+
 @define
 class FinishExperimentRequestSchema:
+    status: ExperimentStatus
     dataset_level_stats: Optional[List[EvaluationResult]] = field(factory=list)
 
 
@@ -343,13 +351,6 @@ class ListExperimentUUIDsFilters:
     experiment_uuids: Optional[List[str]] = None
 
 
-class ExperimentStatus(str, Enum):
-    PENDING = "pending"
-    RUNNING = "running"
-    COMPLETED = "completed"
-    FAILED = "failed"
-
-
 class StatisticOperation(str, Enum):
     MEAN = "mean"
     MEDIAN = "median"
```
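ExperimentStatus moves above FinishExperimentRequestSchema because the schema now annotates a field with it, and the name must already be defined when attrs builds the class. Since the enum subclasses str, its members compare equal to plain strings and serialize cleanly. A rough illustration, using a hypothetical FinishRequest in place of the real schema:

```python
from enum import Enum
from typing import List, Optional

from attrs import define, field


class ExperimentStatus(str, Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


@define
class FinishRequest:  # hypothetical stand-in for FinishExperimentRequestSchema
    status: ExperimentStatus
    dataset_level_stats: Optional[List[str]] = field(factory=list)


req = FinishRequest(status=ExperimentStatus.FAILED)
print(req.status == "failed")  # True: a str-Enum compares equal to its value
print(req.status.value)        # "failed" is what lands in the JSON payload
```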

parea/utils/trace_integrations/instructor.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -71,8 +71,11 @@ def __call__(
         for key in ["max_retries", "response_model", "validation_context", "mode", "args"]:
             if kwargs.get(key):
                 metadata[key] = kwargs[key]
+        trace_name = "instructor"
+        if "response_model" in kwargs and kwargs["response_model"] and hasattr(kwargs["response_model"], "__name__"):
+            trace_name = kwargs["response_model"].__name__
         return trace(
-            name="instructor",
+            name=trace_name,
             overwrite_trace_id=trace_id,
             overwrite_inputs=inputs,
             metadata=metadata,
```
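The instructor integration now names the trace after the response_model class when one is provided, falling back to the generic "instructor" otherwise. A quick sketch of the effect (the Pydantic model and the kwargs dict are hypothetical):

```python
from pydantic import BaseModel


class UserProfile(BaseModel):  # hypothetical response model
    name: str


kwargs = {"response_model": UserProfile, "max_retries": 2}

trace_name = "instructor"
if "response_model" in kwargs and kwargs["response_model"] and hasattr(kwargs["response_model"], "__name__"):
    trace_name = kwargs["response_model"].__name__

print(trace_name)  # "UserProfile" — traces become searchable by model name
```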
