Skip to content

Commit 9ebf762

Browse files
authored
refactor: Run Update Mechanics (ENG-1296) (#53)
* Refactored run update mechanics to handle more data at a higher update frequency * Push an initial update right when the run starts * Remove unused _pending_artifact * Send full attribute data in the final run span
1 parent c0a712b commit 9ebf762

File tree

4 files changed

+97
-46
lines changed

4 files changed

+97
-46
lines changed

dreadnode/api/util.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,13 @@ def process_run(run: RawRun) -> Run:
5959
for references, converted in ((run.inputs, inputs), (run.outputs, outputs)):
6060
for ref in references:
6161
if (_object := run.objects.get(ref.hash)) is None:
62-
logger.error("Object %s not found in run %s", ref.hash, run.id)
62+
if run.status != "pending": # In-progress runs may not have all the objects ready
63+
logger.error("Object %s not found in run %s", ref.hash, run.id)
6364
continue
6465

6566
if (_schema := run.object_schemas.get(_object.schema_hash)) is None:
66-
logger.error("Schema for object %s not found in run %s", ref.hash, run.id)
67+
if run.status != "pending":
68+
logger.error("Schema for object %s not found in run %s", ref.hash, run.id)
6769
continue
6870

6971
if isinstance(_object, RawObjectVal):
@@ -123,11 +125,13 @@ def process_task(task: RawTask, run: RawRun) -> Task:
123125
continue
124126

125127
if (_object := run.objects.get(ref.hash)) is None:
126-
logger.error("Object %s not found in run %s", ref.hash, run.id)
128+
if run.status != "pending":
129+
logger.error("Object %s not found in run %s", ref.hash, run.id)
127130
continue
128131

129132
if (_schema := run.object_schemas.get(_object.schema_hash)) is None:
130-
logger.error("Schema for object %s not found in run %s", ref.hash, run.id)
133+
if run.status != "pending":
134+
logger.error("Schema for object %s not found in run %s", ref.hash, run.id)
131135
continue
132136

133137
if isinstance(_object, RawObjectVal):

dreadnode/main.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def initialize(self) -> None:
230230
self._api.list_projects()
231231
except Exception as e:
232232
raise RuntimeError(
233-
"Failed to connect to the Dreadnode server.",
233+
f"Failed to connect to the Dreadnode server: {e}",
234234
) from e
235235

236236
headers = {"User-Agent": f"dreadnode/{VERSION}", "X-Api-Key": self.token}
@@ -707,23 +707,25 @@ def tag(self, *tag: str, to: ToObject = "task-or-run") -> None:
707707
@handle_internal_errors()
708708
def push_update(self) -> None:
709709
"""
710-
Push any pending metric or parameter data to the server.
710+
Push any pending run data to the server before run completion.
711711
712712
This is useful for ensuring that the UI is up to date with the
713-
latest data. Otherwise, all data for the run will be pushed
714-
automatically when the run is closed.
713+
latest data. Data is automatically pushed periodically, but
714+
you can call this method to force a push.
715715
716716
Example:
717717
```
718718
with dreadnode.run("my_run"):
719719
dreadnode.log_params(...)
720720
dreadnode.log_metric(...)
721721
dreadnode.push_update()
722+
723+
# do more work
722724
"""
723725
if (run := current_run_span.get()) is None:
724726
raise RuntimeError("Run updates must be pushed within a run")
725727

726-
run.push_update()
728+
run.push_update(force=True)
727729

728730
@handle_internal_errors()
729731
def log_param(

dreadnode/task.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
320320
metric = await scorer(output)
321321
span.log_metric(scorer.name, metric, origin=output)
322322

323+
# Trigger a run update whenever a task completes
324+
run.push_update()
325+
323326
return span
324327

325328
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:

dreadnode/tracing/span.py

Lines changed: 79 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import hashlib
22
import logging
3+
import time
34
import types
45
import typing as t
56
from contextvars import ContextVar, Token
@@ -233,19 +234,30 @@ def __init__(
233234
*,
234235
metrics: MetricDict | None = None,
235236
params: JsonDict | None = None,
236-
inputs: JsonDict | None = None,
237+
inputs: list[ObjectRef] | None = None,
238+
outputs: list[ObjectRef] | None = None,
239+
objects: dict[str, Object] | None = None,
240+
object_schemas: dict[str, JsonDict] | None = None,
237241
) -> None:
238242
attributes: AnyDict = {
239243
SPAN_ATTRIBUTE_RUN_ID: run_id,
240244
SPAN_ATTRIBUTE_PROJECT: project,
245+
**({SPAN_ATTRIBUTE_METRICS: metrics} if metrics else {}),
246+
**({SPAN_ATTRIBUTE_PARAMS: params} if params else {}),
247+
**({SPAN_ATTRIBUTE_INPUTS: inputs} if inputs else {}),
248+
**({SPAN_ATTRIBUTE_OUTPUTS: outputs} if outputs else {}),
249+
**({SPAN_ATTRIBUTE_OBJECTS: objects} if objects else {}),
250+
**({SPAN_ATTRIBUTE_OBJECT_SCHEMAS: object_schemas} if object_schemas else {}),
241251
}
242252

243-
if metrics:
244-
attributes[SPAN_ATTRIBUTE_METRICS] = metrics
245-
if params:
246-
attributes[SPAN_ATTRIBUTE_PARAMS] = params
247-
if inputs:
248-
attributes[SPAN_ATTRIBUTE_INPUTS] = inputs
253+
# Mark objects and schemas as large attributes if present
254+
if objects or object_schemas:
255+
large_attrs = []
256+
if objects:
257+
large_attrs.append(SPAN_ATTRIBUTE_OBJECTS)
258+
if object_schemas:
259+
large_attrs.append(SPAN_ATTRIBUTE_OBJECT_SCHEMAS)
260+
attributes[SPAN_ATTRIBUTE_LARGE_ATTRIBUTES] = large_attrs
249261

250262
super().__init__(f"run.{run_id}.update", attributes, tracer, type="run_update")
251263

@@ -265,8 +277,10 @@ def __init__(
265277
run_id: str | None = None,
266278
tags: t.Sequence[str] | None = None,
267279
autolog: bool = True,
280+
update_frequency: int = 5,
268281
) -> None:
269282
self.autolog = autolog
283+
self.project = project
270284

271285
self._params = params or {}
272286
self._metrics = metrics or {}
@@ -281,10 +295,16 @@ def __init__(
281295
storage=self._artifact_storage,
282296
prefix_path=prefix_path,
283297
)
284-
self.project = project
285298

286-
self._last_pushed_params = deepcopy(self._params)
287-
self._last_pushed_metrics = deepcopy(self._metrics)
299+
# Update mechanics
300+
self._last_update_time = time.time()
301+
self._update_frequency = update_frequency
302+
self._pending_params = deepcopy(self._params)
303+
self._pending_inputs = deepcopy(self._inputs)
304+
self._pending_outputs = deepcopy(self._outputs)
305+
self._pending_metrics = deepcopy(self._metrics)
306+
self._pending_objects = deepcopy(self._objects)
307+
self._pending_object_schemas = deepcopy(self._object_schemas)
288308

289309
self._context_token: Token[RunSpan | None] | None = None # contextvars context
290310
self._file_system = file_system
@@ -293,8 +313,6 @@ def __init__(
293313
attributes = {
294314
SPAN_ATTRIBUTE_RUN_ID: str(run_id or ULID()),
295315
SPAN_ATTRIBUTE_PROJECT: project,
296-
SPAN_ATTRIBUTE_PARAMS: self._params,
297-
SPAN_ATTRIBUTE_METRICS: self._metrics,
298316
**attributes,
299317
}
300318
super().__init__(name, attributes, tracer, type="run", tags=tags)
@@ -304,15 +322,20 @@ def __enter__(self) -> te.Self:
304322
raise RuntimeError("You cannot start a run span within another run")
305323

306324
self._context_token = current_run_span.set(self)
307-
return super().__enter__()
325+
span = super().__enter__()
326+
self.push_update(force=True)
327+
return span
308328

309329
def __exit__(
310330
self,
311331
exc_type: type[BaseException] | None,
312332
exc_value: BaseException | None,
313333
traceback: types.TracebackType | None,
314334
) -> None:
315-
self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._params)
335+
# When we finally close out the final span, include all the
336+
# full data attributes, so we can skip the update spans during
337+
# db queries later.
338+
self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._params, schema=False)
316339
self.set_attribute(SPAN_ATTRIBUTE_INPUTS, self._inputs, schema=False)
317340
self.set_attribute(SPAN_ATTRIBUTE_OUTPUTS, self._outputs, schema=False)
318341
self.set_attribute(SPAN_ATTRIBUTE_METRICS, self._metrics, schema=False)
@@ -335,32 +358,46 @@ def __exit__(
335358
if self._context_token is not None:
336359
current_run_span.reset(self._context_token)
337360

338-
def push_update(self) -> None:
361+
def push_update(self, *, force: bool = False) -> None:
339362
if self._span is None:
340363
return
341364

342-
metrics: MetricDict | None = None
343-
if self._last_pushed_metrics != self._metrics:
344-
metrics = self._metrics
345-
self._last_pushed_metrics = deepcopy(self._metrics)
346-
347-
params: JsonDict | None = None
348-
if self._last_pushed_params != self._params:
349-
params = self._params
350-
self._last_pushed_params = deepcopy(self._params)
365+
current_time = time.time()
366+
force_update = force or (current_time - self._last_update_time >= self._update_frequency)
367+
should_update = force_update and (
368+
self._pending_params
369+
or self._pending_inputs
370+
or self._pending_outputs
371+
or self._pending_metrics
372+
or self._pending_objects
373+
or self._pending_object_schemas
374+
)
351375

352-
if metrics is None and params is None:
376+
if not should_update:
353377
return
354378

355379
with RunUpdateSpan(
356380
run_id=self.run_id,
357381
project=self.project,
358382
tracer=self._tracer,
359-
params=params,
360-
metrics=metrics,
383+
metrics=self._pending_metrics if self._pending_metrics else None,
384+
params=self._pending_params if self._pending_params else None,
385+
inputs=self._pending_inputs if self._pending_inputs else None,
386+
outputs=self._pending_outputs if self._pending_outputs else None,
387+
objects=self._pending_objects if self._pending_objects else None,
388+
object_schemas=self._pending_object_schemas if self._pending_object_schemas else None,
361389
):
362390
pass
363391

392+
self._pending_metrics.clear()
393+
self._pending_params.clear()
394+
self._pending_inputs.clear()
395+
self._pending_outputs.clear()
396+
self._pending_objects.clear()
397+
self._pending_object_schemas.clear()
398+
399+
self._last_update_time = current_time
400+
364401
@property
365402
def run_id(self) -> str:
366403
return str(self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, ""))
@@ -384,6 +421,7 @@ def log_object(
384421
# Store schema if new
385422
if schema_hash not in self._object_schemas:
386423
self._object_schemas[schema_hash] = serialized.schema
424+
self._pending_object_schemas[schema_hash] = serialized.schema
387425

388426
# Check if we already have this exact composite hash
389427
if composite_hash not in self._objects:
@@ -392,8 +430,7 @@ def log_object(
392430

393431
# Store with composite hash so we can look it up by the combination
394432
self._objects[composite_hash] = obj
395-
396-
object_ = self._objects[composite_hash]
433+
self._pending_objects[composite_hash] = obj
397434

398435
# Build event attributes, use composite hash in events
399436
event_attributes = {
@@ -407,7 +444,9 @@ def log_object(
407444
event_attributes[EVENT_ATTRIBUTE_OBJECT_LABEL] = label
408445

409446
self.log_event(name=event_name, attributes=event_attributes)
410-
return object_.hash
447+
self.push_update()
448+
449+
return composite_hash
411450

412451
def _store_file_by_hash(self, data: bytes, full_path: str) -> str:
413452
"""
@@ -488,9 +527,10 @@ def log_param(self, key: str, value: t.Any) -> None:
488527
def log_params(self, **params: t.Any) -> None:
489528
for key, value in params.items():
490529
self._params[key] = value
530+
self._pending_params[key] = value
491531

492-
# Always push updates for run params
493-
self.push_update()
532+
# Params should get pushed immediately
533+
self.push_update(force=True)
494534

495535
@property
496536
def inputs(self) -> AnyDict:
@@ -510,7 +550,9 @@ def log_input(
510550
label=label,
511551
event_name=EVENT_NAME_OBJECT_INPUT,
512552
)
513-
self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))
553+
object_ref = ObjectRef(name, label=label, hash=hash_, attributes=attributes)
554+
self._inputs.append(object_ref)
555+
self._pending_inputs.append(object_ref)
514556

515557
def log_artifact(
516558
self,
@@ -529,11 +571,8 @@ def log_artifact(
529571
Raises:
530572
FileNotFoundError: If the path doesn't exist
531573
"""
532-
533574
artifact_tree = self._artifact_tree_builder.process_artifact(local_uri)
534-
535575
self._artifact_merger.add_tree(artifact_tree)
536-
537576
self._artifacts = self._artifact_merger.get_merged_trees()
538577

539578
@property
@@ -601,6 +640,7 @@ def log_metric(
601640
if mode is not None:
602641
metric = metric.apply_mode(mode, metrics)
603642
metrics.append(metric)
643+
self._pending_metrics.setdefault(key, []).append(metric)
604644

605645
return metric
606646

@@ -622,7 +662,9 @@ def log_output(
622662
label=label,
623663
event_name=EVENT_NAME_OBJECT_OUTPUT,
624664
)
625-
self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))
665+
object_ref = ObjectRef(name, label=label, hash=hash_, attributes=attributes)
666+
self._outputs.append(object_ref)
667+
self._pending_outputs.append(object_ref)
626668

627669

628670
class TaskSpan(Span, t.Generic[R]):

0 commit comments

Comments (0)