44import os
55
66from attrs import asdict , define , field
7+ from cattrs import structure
78
89from parea .api_client import HTTPClient
910from parea .constants import PAREA_OS_ENV_EXPERIMENT_UUID
1011from parea .helpers import serialize_metadata_values
1112from parea .schemas .log import TraceIntegrations
12- from parea .schemas .models import TraceLog , UpdateLog
13+ from parea .schemas .models import CreateGetProjectResponseSchema , TraceLog , UpdateLog
1314from parea .utils .universal_encoder import json_dumps
1415
1516LOG_ENDPOINT = "/trace_log"
2021class PareaLogger :
2122 _client : HTTPClient = field (init = False , default = None )
2223 _project_uuid : str = field (init = False , default = None )
24+ _project_name : str = field (init = False , default = None )
2325
2426 def set_client (self , client : HTTPClient ) -> None :
2527 self ._client = client
2628
27- def set_project_uuid (self , project_uuid : str ) -> None :
29+ def set_project_uuid (self , project_uuid : str , project_name : str ) -> None :
2830 self ._project_uuid = project_uuid
31+ self ._project_name = project_name
32+
33+ def _get_project_uuid (self ) -> str :
34+ if not self ._project_uuid :
35+ self ._project_uuid = self ._create_or_get_project (self ._project_name or "default" ).uuid
36+ return self ._project_uuid
37+
38+ def _create_or_get_project (self , name : str ) -> CreateGetProjectResponseSchema :
39+ r = self ._client .request (
40+ "POST" ,
41+ "/project" ,
42+ data = {"name" : name },
43+ )
44+ return structure (r .json (), CreateGetProjectResponseSchema )
2945
3046 def update_log (self , data : UpdateLog ) -> None :
3147 data = serialize_metadata_values (data )
@@ -37,7 +53,7 @@ def update_log(self, data: UpdateLog) -> None:
3753
3854 def record_log (self , data : TraceLog ) -> None :
3955 data = serialize_metadata_values (data )
40- data .project_uuid = self ._project_uuid
56+ data .project_uuid = self ._get_project_uuid ()
4157 self ._client .request (
4258 "POST" ,
4359 LOG_ENDPOINT ,
@@ -60,7 +76,7 @@ def default_log(self, data: TraceLog) -> None:
6076 self .record_log (data )
6177
6278 def record_vendor_log (self , data : Dict [str , Any ], vendor : TraceIntegrations ) -> None :
63- data ["project_uuid" ] = self ._project_uuid
79+ data ["project_uuid" ] = self ._get_project_uuid ()
6480 if experiment_uuid := os .getenv (PAREA_OS_ENV_EXPERIMENT_UUID , None ):
6581 data ["experiment_uuid" ] = experiment_uuid
6682 self ._client .add_integration ("langchain" )
@@ -71,7 +87,7 @@ def record_vendor_log(self, data: Dict[str, Any], vendor: TraceIntegrations) ->
7187 )
7288
7389 async def arecord_vendor_log (self , data : Dict [str , Any ], vendor : TraceIntegrations ) -> None :
74- data ["project_uuid" ] = self ._project_uuid
90+ data ["project_uuid" ] = self ._get_project_uuid ()
7591 if experiment_uuid := os .getenv (PAREA_OS_ENV_EXPERIMENT_UUID , None ):
7692 data ["experiment_uuid" ] = experiment_uuid
7793 self ._client .add_integration ("langchain" )
0 commit comments