|
15 | 15 | from typing import List, Optional, Union, Dict, Any |
16 | 16 | from agentops.client import Client |
17 | 17 | from agentops.sdk.core import TraceContext, tracer |
18 | | -from agentops.sdk.decorators import trace, session, agent, task, workflow, operation |
| 18 | +from agentops.sdk.decorators import trace, session, agent, task, workflow, operation, tool |
| 19 | +from agentops.enums import TraceState, SUCCESS, ERROR, UNSET |
| 20 | +from opentelemetry.trace.status import StatusCode |
19 | 21 |
|
20 | 22 | from agentops.logging.config import logger |
| 23 | +import threading |
21 | 24 |
|
22 | | -# Client global instance; one per process runtime |
23 | | -_client = Client() |
| 25 | +# Thread-safe client management |
| 26 | +_client_lock = threading.Lock() |
| 27 | +_client = None |
24 | 28 |
|
25 | 29 |
|
26 | 30 | def get_client() -> Client: |
27 | | - """Get the singleton client instance""" |
| 31 | + """Get the singleton client instance in a thread-safe manner""" |
28 | 32 | global _client |
29 | 33 |
|
| 34 | + # Double-checked locking pattern for thread safety |
| 35 | + if _client is None: |
| 36 | + with _client_lock: |
| 37 | + if _client is None: |
| 38 | + _client = Client() |
| 39 | + |
30 | 40 | return _client |
31 | 41 |
|
32 | 42 |
|
@@ -106,24 +116,31 @@ def init( |
106 | 116 | elif default_tags: |
107 | 117 | merged_tags = default_tags |
108 | 118 |
|
109 | | - return _client.init( |
110 | | - api_key=api_key, |
111 | | - endpoint=endpoint, |
112 | | - app_url=app_url, |
113 | | - max_wait_time=max_wait_time, |
114 | | - max_queue_size=max_queue_size, |
115 | | - default_tags=merged_tags, |
116 | | - trace_name=trace_name, |
117 | | - instrument_llm_calls=instrument_llm_calls, |
118 | | - auto_start_session=auto_start_session, |
119 | | - auto_init=auto_init, |
120 | | - skip_auto_end_session=skip_auto_end_session, |
121 | | - env_data_opt_out=env_data_opt_out, |
122 | | - log_level=log_level, |
123 | | - fail_safe=fail_safe, |
124 | | - exporter_endpoint=exporter_endpoint, |
| 119 | + # Prepare initialization arguments |
| 120 | + init_kwargs = { |
| 121 | + "api_key": api_key, |
| 122 | + "endpoint": endpoint, |
| 123 | + "app_url": app_url, |
| 124 | + "max_wait_time": max_wait_time, |
| 125 | + "max_queue_size": max_queue_size, |
| 126 | + "default_tags": merged_tags, |
| 127 | + "trace_name": trace_name, |
| 128 | + "instrument_llm_calls": instrument_llm_calls, |
| 129 | + "auto_start_session": auto_start_session, |
| 130 | + "auto_init": auto_init, |
| 131 | + "skip_auto_end_session": skip_auto_end_session, |
| 132 | + "env_data_opt_out": env_data_opt_out, |
| 133 | + "log_level": log_level, |
| 134 | + "fail_safe": fail_safe, |
| 135 | + "exporter_endpoint": exporter_endpoint, |
125 | 136 | **kwargs, |
126 | | - ) |
| 137 | + } |
| 138 | + |
| 139 | + # Get the current client instance (creates new one if needed) |
| 140 | + client = get_client() |
| 141 | + |
| 142 | + # Initialize the client directly |
| 143 | + return client.init(**init_kwargs) |
127 | 144 |
|
128 | 145 |
|
129 | 146 | def configure(**kwargs): |
@@ -173,7 +190,8 @@ def configure(**kwargs): |
173 | 190 | if invalid_params: |
174 | 191 | logger.warning(f"Invalid configuration parameters: {invalid_params}") |
175 | 192 |
|
176 | | - _client.configure(**kwargs) |
| 193 | + client = get_client() |
| 194 | + client.configure(**kwargs) |
177 | 195 |
|
178 | 196 |
|
179 | 197 | def start_trace( |
@@ -207,7 +225,9 @@ def start_trace( |
207 | 225 | return tracer.start_trace(trace_name=trace_name, tags=tags) |
208 | 226 |
|
209 | 227 |
|
210 | | -def end_trace(trace_context: Optional[TraceContext] = None, end_state: str = "Success") -> None: |
| 228 | +def end_trace( |
| 229 | + trace_context: Optional[TraceContext] = None, end_state: Union[TraceState, StatusCode, str] = TraceState.SUCCESS |
| 230 | +) -> None: |
211 | 231 | """ |
212 | 232 | Ends a trace (its root span) and finalizes it. |
213 | 233 | If no trace_context is provided, ends all active session spans. |
@@ -246,4 +266,12 @@ def end_trace(trace_context: Optional[TraceContext] = None, end_state: str = "Su |
246 | 266 | "workflow", |
247 | 267 | "operation", |
248 | 268 | "tracer", |
| 269 | + "tool", |
| 270 | + # Trace state enums |
| 271 | + "TraceState", |
| 272 | + "SUCCESS", |
| 273 | + "ERROR", |
| 274 | + "UNSET", |
| 275 | + # OpenTelemetry status codes (for advanced users) |
| 276 | + "StatusCode", |
249 | 277 | ] |
0 commit comments