11from __future__ import annotations
22
3+ import json
34import sys
45import traceback
6+ from collections import defaultdict
57from collections .abc import Mapping , Sequence
68from dataclasses import dataclass , field
79from threading import Lock
810from typing import TYPE_CHECKING , Any , Callable , cast
9- from weakref import WeakKeyDictionary , WeakSet
11+ from weakref import WeakKeyDictionary , WeakValueDictionary
1012
1113import opentelemetry .trace as trace_api
1214from opentelemetry import context as context_api
4345 ValidationError = None
4446
4547
46- OPEN_SPANS : WeakSet [ _LogfireWrappedSpan ] = WeakSet ()
48+ OPEN_SPANS : WeakValueDictionary [ tuple [ int , int ], _LogfireWrappedSpan ] = WeakValueDictionary ()
4749
4850
4951@dataclass
@@ -112,6 +114,23 @@ def force_flush(self, timeout_millis: int = 30000) -> bool:
112114 return True # pragma: no cover
113115
114116
117+ @dataclass
118+ class SpanMetric :
119+ details : dict [tuple [tuple [str , otel_types .AttributeValue ], ...], float ] = field (
120+ default_factory = lambda : defaultdict (int )
121+ )
122+
123+ def dump (self ):
124+ return {
125+ 'details' : [{'attributes' : dict (attributes ), 'total' : total } for attributes , total in self .details .items ()],
126+ 'total' : sum (total for total in self .details .values ()),
127+ }
128+
129+ def increment (self , attributes : Mapping [str , otel_types .AttributeValue ], value : float ):
130+ key = tuple (sorted (attributes .items ()))
131+ self .details [key ] += value
132+
133+
115134@dataclass (eq = False )
116135class _LogfireWrappedSpan (trace_api .Span , ReadableSpan ):
117136 """A span that wraps another span and overrides some behaviors in a logfire-specific way.
@@ -124,14 +143,24 @@ class _LogfireWrappedSpan(trace_api.Span, ReadableSpan):
124143
125144 span : Span
126145 ns_timestamp_generator : Callable [[], int ]
146+ record_metrics : bool
147+ metrics : dict [str , SpanMetric ] = field (default_factory = lambda : defaultdict (SpanMetric ))
127148
128149 def __post_init__ (self ):
129- OPEN_SPANS . add ( self )
150+ OPEN_SPANS [ self . _open_spans_key ()] = self
130151
131152 def end (self , end_time : int | None = None ) -> None :
132- OPEN_SPANS .discard (self )
153+ with handle_internal_errors :
154+ OPEN_SPANS .pop (self ._open_spans_key (), None )
155+ if self .metrics :
156+ self .span .set_attribute (
157+ 'logfire.metrics' , json .dumps ({name : metric .dump () for name , metric in self .metrics .items ()})
158+ )
133159 self .span .end (end_time or self .ns_timestamp_generator ())
134160
161+ def _open_spans_key (self ):
162+ return _open_spans_key (self .span .get_span_context ())
163+
135164 def get_span_context (self ) -> SpanContext :
136165 return self .span .get_span_context ()
137166
@@ -175,6 +204,14 @@ def record_exception(
175204 timestamp = timestamp or self .ns_timestamp_generator ()
176205 record_exception (self .span , exception , attributes = attributes , timestamp = timestamp , escaped = escaped )
177206
207+ def increment_metric (self , name : str , attributes : Mapping [str , otel_types .AttributeValue ], value : float ) -> None :
208+ if not self .is_recording () or not self .record_metrics :
209+ return
210+
211+ self .metrics [name ].increment (attributes , value )
212+ if self .parent and (parent := OPEN_SPANS .get (_open_spans_key (self .parent ))):
213+ parent .increment_metric (name , attributes , value )
214+
178215 def __exit__ (self , exc_type : type [BaseException ] | None , exc_value : BaseException | None , traceback : Any ) -> None :
179216 if self .is_recording ():
180217 if isinstance (exc_value , BaseException ):
@@ -187,6 +224,10 @@ def __getattr__(self, name: str) -> Any:
187224 return getattr (self .span , name )
188225
189226
227+ def _open_spans_key (ctx : SpanContext ) -> tuple [int , int ]:
228+ return ctx .trace_id , ctx .span_id
229+
230+
190231@dataclass
191232class _ProxyTracer (Tracer ):
192233 """A tracer that wraps another internal tracer allowing it to be re-assigned."""
@@ -216,7 +257,11 @@ def start_span(
216257 record_exception : bool = True ,
217258 set_status_on_exception : bool = True ,
218259 ) -> Span :
219- start_time = start_time or self .provider .config .advanced .ns_timestamp_generator ()
260+ config = self .provider .config
261+ ns_timestamp_generator = config .advanced .ns_timestamp_generator
262+ record_metrics : bool = not isinstance (config .metrics , (bool , type (None ))) and config .metrics .collect_in_spans
263+
264+ start_time = start_time or ns_timestamp_generator ()
220265
221266 # Make a copy of the attributes since this method can be called by arbitrary external code,
222267 # e.g. third party instrumentation.
@@ -241,7 +286,8 @@ def start_span(
241286 )
242287 return _LogfireWrappedSpan (
243288 span ,
244- ns_timestamp_generator = self .provider .config .advanced .ns_timestamp_generator ,
289+ ns_timestamp_generator = ns_timestamp_generator ,
290+ record_metrics = record_metrics ,
245291 )
246292
247293 # This means that `with start_as_current_span(...):`
0 commit comments