11import hashlib
22import logging
3+ import time
34import types
45import typing as t
56from contextvars import ContextVar , Token
@@ -233,19 +234,30 @@ def __init__(
233234 * ,
234235 metrics : MetricDict | None = None ,
235236 params : JsonDict | None = None ,
236- inputs : JsonDict | None = None ,
237+ inputs : list [ObjectRef ] | None = None ,
238+ outputs : list [ObjectRef ] | None = None ,
239+ objects : dict [str , Object ] | None = None ,
240+ object_schemas : dict [str , JsonDict ] | None = None ,
237241 ) -> None :
238242 attributes : AnyDict = {
239243 SPAN_ATTRIBUTE_RUN_ID : run_id ,
240244 SPAN_ATTRIBUTE_PROJECT : project ,
245+ ** ({SPAN_ATTRIBUTE_METRICS : metrics } if metrics else {}),
246+ ** ({SPAN_ATTRIBUTE_PARAMS : params } if params else {}),
247+ ** ({SPAN_ATTRIBUTE_INPUTS : inputs } if inputs else {}),
248+ ** ({SPAN_ATTRIBUTE_OUTPUTS : outputs } if outputs else {}),
249+ ** ({SPAN_ATTRIBUTE_OBJECTS : objects } if objects else {}),
250+ ** ({SPAN_ATTRIBUTE_OBJECT_SCHEMAS : object_schemas } if object_schemas else {}),
241251 }
242252
243- if metrics :
244- attributes [SPAN_ATTRIBUTE_METRICS ] = metrics
245- if params :
246- attributes [SPAN_ATTRIBUTE_PARAMS ] = params
247- if inputs :
248- attributes [SPAN_ATTRIBUTE_INPUTS ] = inputs
253+ # Mark objects and schemas as large attributes if present
254+ if objects or object_schemas :
255+ large_attrs = []
256+ if objects :
257+ large_attrs .append (SPAN_ATTRIBUTE_OBJECTS )
258+ if object_schemas :
259+ large_attrs .append (SPAN_ATTRIBUTE_OBJECT_SCHEMAS )
260+ attributes [SPAN_ATTRIBUTE_LARGE_ATTRIBUTES ] = large_attrs
249261
250262 super ().__init__ (f"run.{ run_id } .update" , attributes , tracer , type = "run_update" )
251263
@@ -265,8 +277,10 @@ def __init__(
265277 run_id : str | None = None ,
266278 tags : t .Sequence [str ] | None = None ,
267279 autolog : bool = True ,
280+ update_frequency : int = 5 ,
268281 ) -> None :
269282 self .autolog = autolog
283+ self .project = project
270284
271285 self ._params = params or {}
272286 self ._metrics = metrics or {}
@@ -281,10 +295,16 @@ def __init__(
281295 storage = self ._artifact_storage ,
282296 prefix_path = prefix_path ,
283297 )
284- self .project = project
285298
286- self ._last_pushed_params = deepcopy (self ._params )
287- self ._last_pushed_metrics = deepcopy (self ._metrics )
299+ # Update mechanics
300+ self ._last_update_time = time .time ()
301+ self ._update_frequency = update_frequency
302+ self ._pending_params = deepcopy (self ._params )
303+ self ._pending_inputs = deepcopy (self ._inputs )
304+ self ._pending_outputs = deepcopy (self ._outputs )
305+ self ._pending_metrics = deepcopy (self ._metrics )
306+ self ._pending_objects = deepcopy (self ._objects )
307+ self ._pending_object_schemas = deepcopy (self ._object_schemas )
288308
289309 self ._context_token : Token [RunSpan | None ] | None = None # contextvars context
290310 self ._file_system = file_system
@@ -293,8 +313,6 @@ def __init__(
293313 attributes = {
294314 SPAN_ATTRIBUTE_RUN_ID : str (run_id or ULID ()),
295315 SPAN_ATTRIBUTE_PROJECT : project ,
296- SPAN_ATTRIBUTE_PARAMS : self ._params ,
297- SPAN_ATTRIBUTE_METRICS : self ._metrics ,
298316 ** attributes ,
299317 }
300318 super ().__init__ (name , attributes , tracer , type = "run" , tags = tags )
@@ -304,15 +322,20 @@ def __enter__(self) -> te.Self:
304322 raise RuntimeError ("You cannot start a run span within another run" )
305323
306324 self ._context_token = current_run_span .set (self )
307- return super ().__enter__ ()
325+ span = super ().__enter__ ()
326+ self .push_update (force = True )
327+ return span
308328
309329 def __exit__ (
310330 self ,
311331 exc_type : type [BaseException ] | None ,
312332 exc_value : BaseException | None ,
313333 traceback : types .TracebackType | None ,
314334 ) -> None :
315- self .set_attribute (SPAN_ATTRIBUTE_PARAMS , self ._params )
335+ # When we finally close out the final span, include all the
336+ # full data attributes, so we can skip the update spans during
337+ # db queries later.
338+ self .set_attribute (SPAN_ATTRIBUTE_PARAMS , self ._params , schema = False )
316339 self .set_attribute (SPAN_ATTRIBUTE_INPUTS , self ._inputs , schema = False )
317340 self .set_attribute (SPAN_ATTRIBUTE_OUTPUTS , self ._outputs , schema = False )
318341 self .set_attribute (SPAN_ATTRIBUTE_METRICS , self ._metrics , schema = False )
@@ -335,32 +358,46 @@ def __exit__(
335358 if self ._context_token is not None :
336359 current_run_span .reset (self ._context_token )
337360
def push_update(self, *, force: bool = False) -> None:
    """
    Flush buffered run data into a `RunUpdateSpan`.

    Nothing is sent when `self._span` is `None`, when fewer than
    `self._update_frequency` seconds have passed since the last push
    (unless `force` is set), or when every pending buffer is empty.
    After a push the pending buffers are emptied and the throttle
    timer is restarted.

    Args:
        force: Ignore the time-based throttle and push pending data now.
    """
    if self._span is None:
        return

    current_time = time.time()
    window_elapsed = force or (
        current_time - self._last_update_time >= self._update_frequency
    )
    if not window_elapsed:
        return

    pending_buffers = (
        self._pending_params,
        self._pending_inputs,
        self._pending_outputs,
        self._pending_metrics,
        self._pending_objects,
        self._pending_object_schemas,
    )
    # An empty push would only create a useless span, and it must not
    # restart the throttle timer either.
    if not any(pending_buffers):
        return

    # Hand the span shallow snapshots rather than the live buffers, so the
    # clear() calls below cannot empty containers the span still references.
    with RunUpdateSpan(
        run_id=self.run_id,
        project=self.project,
        tracer=self._tracer,
        metrics=dict(self._pending_metrics) or None,
        params=dict(self._pending_params) or None,
        inputs=list(self._pending_inputs) or None,
        outputs=list(self._pending_outputs) or None,
        objects=dict(self._pending_objects) or None,
        object_schemas=dict(self._pending_object_schemas) or None,
    ):
        pass

    for buffer in pending_buffers:
        buffer.clear()

    self._last_update_time = current_time
400+
@property
def run_id(self) -> str:
    """
    Run identifier stored on this span.

    The `SPAN_ATTRIBUTE_RUN_ID` attribute is read with an empty-string
    default and the result is coerced to `str`.
    """
    raw_run_id = self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, "")
    return str(raw_run_id)
@@ -384,6 +421,7 @@ def log_object(
384421 # Store schema if new
385422 if schema_hash not in self ._object_schemas :
386423 self ._object_schemas [schema_hash ] = serialized .schema
424+ self ._pending_object_schemas [schema_hash ] = serialized .schema
387425
388426 # Check if we already have this exact composite hash
389427 if composite_hash not in self ._objects :
@@ -392,8 +430,7 @@ def log_object(
392430
393431 # Store with composite hash so we can look it up by the combination
394432 self ._objects [composite_hash ] = obj
395-
396- object_ = self ._objects [composite_hash ]
433+ self ._pending_objects [composite_hash ] = obj
397434
398435 # Build event attributes, use composite hash in events
399436 event_attributes = {
@@ -407,7 +444,9 @@ def log_object(
407444 event_attributes [EVENT_ATTRIBUTE_OBJECT_LABEL ] = label
408445
409446 self .log_event (name = event_name , attributes = event_attributes )
410- return object_ .hash
447+ self .push_update ()
448+
449+ return composite_hash
411450
412451 def _store_file_by_hash (self , data : bytes , full_path : str ) -> str :
413452 """
@@ -488,9 +527,10 @@ def log_param(self, key: str, value: t.Any) -> None:
def log_params(self, **params: t.Any) -> None:
    """
    Record run parameters and push them right away.

    Each keyword argument is stored in the run's full parameter set and in
    the pending buffer used for the next update; existing keys are
    overwritten.

    Args:
        **params: Parameter names mapped to their values.
    """
    self._params.update(params)
    self._pending_params.update(params)

    # Params should get pushed immediately
    self.push_update(force=True)
494534
495535 @property
496536 def inputs (self ) -> AnyDict :
@@ -510,7 +550,9 @@ def log_input(
510550 label = label ,
511551 event_name = EVENT_NAME_OBJECT_INPUT ,
512552 )
513- self ._inputs .append (ObjectRef (name , label = label , hash = hash_ , attributes = attributes ))
553+ object_ref = ObjectRef (name , label = label , hash = hash_ , attributes = attributes )
554+ self ._inputs .append (object_ref )
555+ self ._pending_inputs .append (object_ref )
514556
515557 def log_artifact (
516558 self ,
@@ -529,11 +571,8 @@ def log_artifact(
529571 Raises:
530572 FileNotFoundError: If the path doesn't exist
531573 """
532-
533574 artifact_tree = self ._artifact_tree_builder .process_artifact (local_uri )
534-
535575 self ._artifact_merger .add_tree (artifact_tree )
536-
537576 self ._artifacts = self ._artifact_merger .get_merged_trees ()
538577
539578 @property
@@ -601,6 +640,7 @@ def log_metric(
601640 if mode is not None :
602641 metric = metric .apply_mode (mode , metrics )
603642 metrics .append (metric )
643+ self ._pending_metrics .setdefault (key , []).append (metric )
604644
605645 return metric
606646
@@ -622,7 +662,9 @@ def log_output(
622662 label = label ,
623663 event_name = EVENT_NAME_OBJECT_OUTPUT ,
624664 )
625- self ._outputs .append (ObjectRef (name , label = label , hash = hash_ , attributes = attributes ))
665+ object_ref = ObjectRef (name , label = label , hash = hash_ , attributes = attributes )
666+ self ._outputs .append (object_ref )
667+ self ._pending_outputs .append (object_ref )
626668
627669
628670class TaskSpan (Span , t .Generic [R ]):
0 commit comments