|
14 | 14 | from parea.cache.cache import Cache |
15 | 15 | from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID |
16 | 16 | from parea.experiment.datasets import create_test_cases, create_test_collection |
17 | | -from parea.helpers import gen_trace_id, serialize_metadata_values |
| 17 | +from parea.helpers import gen_trace_id, serialize_metadata_values, structure_trace_log_from_api |
18 | 18 | from parea.parea_logger import parea_logger |
19 | 19 | from parea.schemas.models import ( |
20 | 20 | Completion, |
|
29 | 29 | FinishExperimentRequestSchema, |
30 | 30 | ProjectSchema, |
31 | 31 | TestCaseCollection, |
| 32 | + TraceLog, |
32 | 33 | UseDeployedPrompt, |
33 | 34 | UseDeployedPromptResponse, |
34 | 35 | ) |
|
48 | 49 | GET_COLLECTION_ENDPOINT = "/collection/{test_collection_identifier}" |
49 | 50 | CREATE_COLLECTION_ENDPOINT = "/collection" |
50 | 51 | ADD_TEST_CASES_ENDPOINT = "/testcases" |
| 52 | +GET_TRACE_LOG_ENDPOINT = "/trace_log/{trace_id}" |
51 | 53 |
|
52 | 54 |
|
53 | 55 | @define |
@@ -336,6 +338,14 @@ def _update_data_and_trace(self, data: Completion) -> Completion: |
336 | 338 |
|
337 | 339 | return data |
338 | 340 |
|
| 341 | + def get_trace_log(self, trace_id: str) -> TraceLog: |
| 342 | + response = self._client.request("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id)) |
| 343 | + return structure_trace_log_from_api(response.json()) |
| 344 | + |
| 345 | + async def aget_trace_log(self, trace_id: str) -> TraceLog: |
| 346 | + response = await self._client.request_async("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id)) |
| 347 | + return structure_trace_log_from_api(response.json()) |
| 348 | + |
339 | 349 |
|
340 | 350 | _initialized_parea_wrapper = False |
341 | 351 |
|
|
0 commit comments