Skip to content

Commit e3f0c63

Browse files
authored
Merge pull request #934 from parea-ai/PAI-1264-finished-experiment-even-on-trace-fail-flag-to-fail-on-first
Pai 1264 finished experiment even on trace fail flag to fail on first
2 parents 94848cf + cecbc3e commit e3f0c63

File tree

4 files changed

+33
-15
lines changed

4 files changed

+33
-15
lines changed

parea/cookbook/evals_and_experiments/run_experiment_using_saved_test_collection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ def func(lang: str, framework: str) -> str:
3636

3737
if __name__ == "__main__":
3838
p.experiment(
39+
name="Hello World Example", # this is the name of the experiment
3940
data="Hello World Example", # this is the name of your Dataset in Parea (Dataset page)
4041
func=func,
41-
).run(name="hello-world-example")
42+
).run()
4243

4344
# Or use a dataset using its ID instead of the name
4445
# p.experiment(

parea/experiment/experiment.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from parea.experiment.dvc import save_results_to_dvc_if_init
2020
from parea.helpers import duplicate_dicts, gen_random_name, is_logging_disabled
2121
from parea.schemas import EvaluationResult
22-
from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, FinishExperimentRequestSchema
22+
from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, ExperimentStatus, FinishExperimentRequestSchema
2323
from parea.utils.trace_utils import thread_ids_running_evals, trace_data
2424
from parea.utils.universal_encoder import json_dumps
2525

@@ -138,13 +138,26 @@ def limit_concurrency_sync(sample):
138138
return func(_parea_target_field=target, **sample_copy)
139139

140140
if inspect.iscoroutinefunction(func):
141-
tasks = [limit_concurrency(sample) for sample in data]
141+
tasks = [asyncio.ensure_future(limit_concurrency(sample)) for sample in data]
142142
else:
143143
executor = ThreadPoolExecutor(max_workers=n_workers)
144144
loop = asyncio.get_event_loop()
145-
tasks = [loop.run_in_executor(executor, partial(limit_concurrency_sync, sample)) for sample in data]
146-
for _task in tqdm_asyncio.as_completed(tasks, total=len_test_cases):
147-
await _task
145+
tasks = [asyncio.ensure_future(loop.run_in_executor(executor, partial(limit_concurrency_sync, sample))) for sample in data]
146+
147+
status = ExperimentStatus.COMPLETED
148+
with tqdm(total=len(tasks), desc="Running samples", unit="sample") as pbar:
149+
try:
150+
for coro in asyncio.as_completed(tasks):
151+
try:
152+
await coro
153+
pbar.update(1)
154+
except Exception as e:
155+
print(f"\nExperiment stopped due to an error: {str(e)}\n")
156+
status = ExperimentStatus.FAILED
157+
for task in tasks:
158+
task.cancel()
159+
except asyncio.CancelledError:
160+
pass
148161

149162
await asyncio.sleep(0.2)
150163
total_evals = len(thread_ids_running_evals.get())
@@ -162,7 +175,7 @@ def limit_concurrency_sync(sample):
162175
else:
163176
dataset_level_eval_results = []
164177

165-
experiment_stats: ExperimentStatsSchema = p.finish_experiment(experiment_uuid, FinishExperimentRequestSchema(dataset_level_stats=dataset_level_eval_results))
178+
experiment_stats: ExperimentStatsSchema = p.finish_experiment(experiment_uuid, FinishExperimentRequestSchema(status=status, dataset_level_stats=dataset_level_eval_results))
166179
stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats)
167180
if dataset_level_eval_results:
168181
stat_name_to_avg_std.update(**{eval_result.name: eval_result.score for eval_result in dataset_level_eval_results})

parea/schemas/models.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,16 @@ class CreateTestCaseCollection(CreateTestCases):
329329
column_names: List[str] = field(factory=list)
330330

331331

332+
class ExperimentStatus(str, Enum):
333+
PENDING = "pending"
334+
RUNNING = "running"
335+
COMPLETED = "completed"
336+
FAILED = "failed"
337+
338+
332339
@define
333340
class FinishExperimentRequestSchema:
341+
status: ExperimentStatus
334342
dataset_level_stats: Optional[List[EvaluationResult]] = field(factory=list)
335343

336344

@@ -343,13 +351,6 @@ class ListExperimentUUIDsFilters:
343351
experiment_uuids: Optional[List[str]] = None
344352

345353

346-
class ExperimentStatus(str, Enum):
347-
PENDING = "pending"
348-
RUNNING = "running"
349-
COMPLETED = "completed"
350-
FAILED = "failed"
351-
352-
353354
class StatisticOperation(str, Enum):
354355
MEAN = "mean"
355356
MEDIAN = "median"

parea/utils/trace_integrations/instructor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,11 @@ def __call__(
7171
for key in ["max_retries", "response_model", "validation_context", "mode", "args"]:
7272
if kwargs.get(key):
7373
metadata[key] = kwargs[key]
74+
trace_name = "instructor"
75+
if "response_model" in kwargs and kwargs["response_model"] and hasattr(kwargs["response_model"], "__name__"):
76+
trace_name = kwargs["response_model"].__name__
7477
return trace(
75-
name="instructor",
78+
name=trace_name,
7679
overwrite_trace_id=trace_id,
7780
overwrite_inputs=inputs,
7881
metadata=metadata,

0 commit comments

Comments
 (0)