Skip to content

Commit 3c219ce

Browse files
authored
Merge pull request #775 from parea-ai/PAI-1063-extend-experiment-list-endpoint
feat: extend experiments list endpoint
2 parents f2a4c29 + 55a759f commit 3c219ce

File tree

4 files changed

+54
-8
lines changed

4 files changed

+54
-8
lines changed

parea/client.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
CreateTestCases,
2626
ExperimentSchema,
2727
ExperimentStatsSchema,
28+
ExperimentWithPinnedStatsSchema,
2829
FeedbackRequest,
2930
FinishExperimentRequestSchema,
3031
ListExperimentUUIDsFilters,
@@ -350,13 +351,13 @@ async def aget_trace_log(self, trace_id: str) -> TraceLog:
350351
response = await self._client.request_async("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
351352
return structure_trace_log_from_api(response.json())
352353

353-
def list_experiment_uuids(self, filter_conditions: Optional[ListExperimentUUIDsFilters] = ListExperimentUUIDsFilters()) -> List[str]:
354+
def list_experiments(self, filter_conditions: Optional[ListExperimentUUIDsFilters] = ListExperimentUUIDsFilters()) -> List[ExperimentWithPinnedStatsSchema]:
354355
response = self._client.request("POST", LIST_EXPERIMENTS_ENDPOINT, data=asdict(filter_conditions))
355-
return response.json()
356+
return structure(response.json(), List[ExperimentWithPinnedStatsSchema])
356357

357-
async def alist_experiment_uuids(self, filter_conditions: Optional[ListExperimentUUIDsFilters] = ListExperimentUUIDsFilters()) -> List[str]:
358+
async def alist_experiments(self, filter_conditions: Optional[ListExperimentUUIDsFilters] = ListExperimentUUIDsFilters()) -> List[ExperimentWithPinnedStatsSchema]:
358359
response = await self._client.request_async("POST", LIST_EXPERIMENTS_ENDPOINT, data=asdict(filter_conditions))
359-
return response.json()
360+
return structure(response.json(), List[ExperimentWithPinnedStatsSchema])
360361

361362
def get_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters()) -> List[TraceLog]:
362363
response = self._client.request("POST", GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment_uuid), data=asdict(filters))

parea/cookbook/list_experiments.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
p = Parea(api_key=os.getenv("PAREA_API_KEY"))
1111

12-
exp_uuids = p.list_experiment_uuids(ListExperimentUUIDsFilters(experiment_name_filter="Greeting"))
13-
print(f"Num. experiments: {len(exp_uuids)}")
14-
trace_logs = p.get_experiment_trace_logs(exp_uuids[0])
12+
experiments = p.list_experiments(ListExperimentUUIDsFilters(experiment_name_filter="Greeting"))
13+
print(f"Num. experiments: {len(experiments)}")
14+
trace_logs = p.get_experiment_trace_logs(experiments[0].uuid)
1515
print(f"Num. trace logs: {len(trace_logs)}")
16+
print(f"Trace log: {trace_logs[0]}")

parea/schemas/models.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,50 @@ class ListExperimentUUIDsFilters:
309309
run_name_filter: Optional[str] = None
310310

311311

312+
class ExperimentStatus(str, Enum):
313+
PENDING = "pending"
314+
RUNNING = "running"
315+
COMPLETED = "completed"
316+
FAILED = "failed"
317+
318+
319+
class StatisticOperation(str, Enum):
320+
MEAN = "mean"
321+
MEDIAN = "median"
322+
VARIANCE = "variance"
323+
STANDARD_DEVIATION = "standard_deviation"
324+
MIN = "min"
325+
MAX = "max"
326+
MSE = "mse"
327+
MAE = "mae"
328+
CORRELATION = "correlation"
329+
SPEARMAN_CORRELATION = "spearman_correlation"
330+
ACCURACY = "accuracy"
331+
CUSTOM = "custom"
332+
333+
334+
@define
335+
class ExperimentPinnedStatistic:
336+
var1: str
337+
operation: StatisticOperation
338+
value: float
339+
var2: Optional[str] = None
340+
341+
342+
@define
343+
class ExperimentWithPinnedStatsSchema:
344+
name: str
345+
uuid: str
346+
created_at: str
347+
run_name: str
348+
project_uuid: str
349+
status: ExperimentStatus
350+
is_public: bool = False
351+
metadata: Optional[Dict[str, str]] = None
352+
pinned_stats: list[ExperimentPinnedStatistic] = []
353+
num_samples: Optional[int] = None
354+
355+
312356
class FilterOperator(str, Enum):
313357
EQUALS = "equals"
314358
NOT_EQUALS = "not_equals"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
66
[tool.poetry]
77
name = "parea-ai"
88
packages = [{ include = "parea" }]
9-
version = "0.2.134"
9+
version = "0.2.135"
1010
description = "Parea python sdk"
1111
readme = "README.md"
1212
authors = ["joel-parea-ai <[email protected]>"]

0 commit comments

Comments
 (0)