
Commit b0741ad

Merge pull request #935 from parea-ai/PAI-1264-finished-experiment-even-on-trace-fail-flag-to-fail-on-first

PAI-1264: finish the experiment even when a trace fails; add a flag to fail on the first error.

2 parents: e3f0c63 + 84a9654

File tree: 3 files changed (+26, -6 lines)

parea/client.py

Lines changed: 3 additions & 0 deletions
@@ -352,6 +352,7 @@ def experiment(
         metadata: Optional[Dict[str, str]] = None,
         dataset_level_evals: Optional[List[Callable]] = None,
         n_workers: int = 10,
+        stop_on_error: bool = False,
     ):
         """
         :param data: If your dataset is defined locally it should be an iterable of k/v
@@ -363,6 +364,7 @@ def experiment(
         :param metadata: Optional metadata to attach to the experiment.
         :param dataset_level_evals: Optional list of functions to run on the dataset level. Each function should accept a list of EvaluatedLog objects and return a float or an EvaluationResult object
         :param n_workers: The number of workers to use for running the experiment.
+        :param stop_on_error: If True, the experiment will stop on the first exception. If False, the experiment will continue running the remaining samples.
         """
         from parea import Experiment
 
@@ -375,6 +377,7 @@ def experiment(
             metadata=metadata,
            dataset_level_evals=dataset_level_evals,
             n_workers=n_workers,
+            stop_on_error=stop_on_error,
         )
 
     def _update_data_and_trace(self, data: Completion) -> Completion:
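
For context, here is a minimal usage sketch of the new flag through the client API. It is illustrative only: the API key, task function, evaluation function, and dataset are placeholders, and it assumes the parameters not shown in this hunk (name, data, func) keep their existing positions in the signature.

from parea import Parea, trace

# Illustrative setup: the API key, task, and dataset below are placeholders.
p = Parea(api_key="YOUR_PAREA_API_KEY")

def exact_match(log) -> float:
    # Toy evaluation: 1.0 when the output equals the target, else 0.0.
    return float(log.output == log.target)

@trace(eval_funcs=[exact_match])
def task(question: str) -> str:
    # Stand-in for an LLM call that might raise on a bad sample.
    return question.upper()

p.experiment(
    name="stop-on-error-demo",                        # illustrative experiment name
    data=[{"question": "hello", "target": "HELLO"}],  # toy local dataset
    func=task,
    n_workers=4,
    stop_on_error=True,  # new flag from this PR: abort on the first exception
).run()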

parea/experiment/experiment.py

Lines changed: 22 additions & 5 deletions
@@ -93,6 +93,7 @@ async def experiment(
     n_trials: int = 1,
     dataset_level_evals: Optional[List[Callable]] = None,
     n_workers: int = 10,
+    stop_on_error: bool = False,
 ) -> ExperimentStatsSchema:
     """Creates an experiment and runs the function on the data iterator.
     param experiment_name: The name of the experiment. Used to organize experiments within a project.
@@ -107,6 +108,7 @@ async def experiment(
     param dataset_level_evals: A list of functions to run on the dataset level. Each function should accept a list of EvaluatedLogs and return a float or a
         EvaluationResult. If a float is returned, the name of the function will be used as the name of the evaluation.
     param n_workers: The number of workers to use for running the experiment.
+    param stop_on_error: If True, the experiment will stop running if an exception is raised.
     """
     if isinstance(data, (str, int)):
         print(f"Fetching test collection: {data}")
@@ -152,10 +154,13 @@ def limit_concurrency_sync(sample):
                 await coro
                 pbar.update(1)
             except Exception as e:
-                print(f"\nExperiment stopped due to an error: {str(e)}\n")
                 status = ExperimentStatus.FAILED
-                for task in tasks:
-                    task.cancel()
+                if stop_on_error:
+                    print(f"\nExperiment stopped due to an error: {str(e)}\n")
+                    for task in tasks:
+                        task.cancel()
+                else:
+                    pbar.update(1)
             except asyncio.CancelledError:
                 pass
 
@@ -220,6 +225,7 @@ class Experiment:
     n_workers: int = field(default=10)
     # The number of times to run the experiment on the same data.
     n_trials: int = field(default=1)
+    stop_on_error: bool = field(default=False)
 
     def __attrs_post_init__(self):
         global _experiments
@@ -253,7 +259,18 @@ def run(self, run_name: Optional[str] = None) -> None:
             experiment_schema: ExperimentSchema = self.p.create_experiment(CreateExperimentRequest(name=self.experiment_name, run_name=self.run_name, metadata=self.metadata))
             self.experiment_uuid = experiment_schema.uuid
             self.experiment_stats = asyncio.run(
-                experiment(self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers)
+                experiment(
+                    self.experiment_name,
+                    self.run_name,
+                    self.data,
+                    self.func,
+                    self.p,
+                    self.experiment_uuid,
+                    self.n_trials,
+                    self.dataset_level_evals,
+                    self.n_workers,
+                    self.stop_on_error,
+                )
             )
         except Exception as e:
             import traceback
@@ -277,7 +294,7 @@ async def arun(self, run_name: Optional[str] = None) -> None:
             )
             self.experiment_uuid = experiment_schema.uuid
             self.experiment_stats = await experiment(
-                self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers
+                self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers, self.stop_on_error
             )
         except Exception as e:
             import traceback
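
The behavioral change in the error-handling hunk above is easiest to see in isolation: a failure now only cancels the outstanding tasks when stop_on_error is set; otherwise the run is marked failed but the remaining samples still execute. The following is a standalone sketch of that control flow, not Parea code; run_all and sample are made-up names for illustration.

import asyncio

# Standalone sketch of the "cancel on first error vs. keep going" pattern.
async def run_all(coros, stop_on_error: bool) -> str:
    status = "COMPLETED"
    tasks = [asyncio.create_task(c) for c in coros]
    for fut in asyncio.as_completed(tasks):
        try:
            await fut
        except Exception as e:
            status = "FAILED"
            if stop_on_error:
                print(f"\nExperiment stopped due to an error: {e}\n")
                for task in tasks:
                    task.cancel()
            # else: record the failure and keep consuming the remaining samples
        except asyncio.CancelledError:
            pass
    return status

async def sample(i: int) -> int:
    if i == 2:
        raise ValueError("boom")  # one sample fails immediately
    await asyncio.sleep(0.01)
    return i

print(asyncio.run(run_all([sample(i) for i in range(5)], stop_on_error=False)))  # FAILED, but all samples run
print(asyncio.run(run_all([sample(i) for i in range(5)], stop_on_error=True)))   # FAILED, remaining samples cancelled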

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "parea-ai"
 packages = [{ include = "parea" }]
-version = "0.2.175"
+version = "0.2.176"
 description = "Parea python sdk"
 readme = "README.md"
 authors = ["joel-parea-ai <[email protected]>"]
