1- from typing import Any , Dict
1+ from typing import Any , Dict , Optional
22
33import json
4+ import logging
45import os
56
67from attrs import asdict , define , field
8+ from cattrs import structure
79
810from parea .api_client import HTTPClient
911from parea .constants import PAREA_OS_ENV_EXPERIMENT_UUID
1012from parea .helpers import serialize_metadata_values
1113from parea .schemas .log import TraceIntegrations
12- from parea .schemas .models import TraceLog , UpdateLog
14+ from parea .schemas .models import CreateGetProjectResponseSchema , TraceLog , UpdateLog
1315from parea .utils .universal_encoder import json_dumps
1416
17+ logger = logging .getLogger ()
18+
1519LOG_ENDPOINT = "/trace_log"
1620VENDOR_LOG_ENDPOINT = "/trace_log/{vendor}"
1721
2024class PareaLogger :
2125 _client : HTTPClient = field (init = False , default = None )
2226 _project_uuid : str = field (init = False , default = None )
27+ _project_name : str = field (init = False , default = None )
2328
2429 def set_client (self , client : HTTPClient ) -> None :
2530 self ._client = client
2631
27- def set_project_uuid (self , project_uuid : str ) -> None :
32+ def set_project_uuid (self , project_uuid : str , project_name : str ) -> None :
2833 self ._project_uuid = project_uuid
34+ self ._project_name = project_name
35+
36+ def _get_project_uuid (self ) -> Optional [str ]:
37+ if not self ._project_uuid :
38+ self ._project_uuid = self ._create_or_get_project (self ._project_name or "default" ).uuid
39+ try :
40+ return self ._project_uuid
41+ except Exception as e :
42+ logger .error (f"PareaLogger: Error getting project uuid for project { self ._project_name } : { e } " )
43+ return None
44+
45+ def _create_or_get_project (self , name : str ) -> CreateGetProjectResponseSchema :
46+ r = self ._client .request (
47+ "POST" ,
48+ "/project" ,
49+ data = {"name" : name },
50+ )
51+ return structure (r .json (), CreateGetProjectResponseSchema )
2952
3053 def update_log (self , data : UpdateLog ) -> None :
3154 data = serialize_metadata_values (data )
@@ -37,7 +60,7 @@ def update_log(self, data: UpdateLog) -> None:
3760
3861 def record_log (self , data : TraceLog ) -> None :
3962 data = serialize_metadata_values (data )
40- data .project_uuid = self ._project_uuid
63+ data .project_uuid = self ._get_project_uuid ()
4164 self ._client .request (
4265 "POST" ,
4366 LOG_ENDPOINT ,
@@ -60,7 +83,7 @@ def default_log(self, data: TraceLog) -> None:
6083 self .record_log (data )
6184
6285 def record_vendor_log (self , data : Dict [str , Any ], vendor : TraceIntegrations ) -> None :
63- data ["project_uuid" ] = self ._project_uuid
86+ data ["project_uuid" ] = self ._get_project_uuid ()
6487 if experiment_uuid := os .getenv (PAREA_OS_ENV_EXPERIMENT_UUID , None ):
6588 data ["experiment_uuid" ] = experiment_uuid
6689 self ._client .add_integration ("langchain" )
@@ -71,7 +94,7 @@ def record_vendor_log(self, data: Dict[str, Any], vendor: TraceIntegrations) ->
7194 )
7295
7396 async def arecord_vendor_log (self , data : Dict [str , Any ], vendor : TraceIntegrations ) -> None :
74- data ["project_uuid" ] = self ._project_uuid
97+ data ["project_uuid" ] = self ._get_project_uuid ()
7598 if experiment_uuid := os .getenv (PAREA_OS_ENV_EXPERIMENT_UUID , None ):
7699 data ["experiment_uuid" ] = experiment_uuid
77100 self ._client .add_integration ("langchain" )
0 commit comments