Skip to content

Commit 91a03ab

Browse files
authored
Fix dependencies. Add get_run_context and continue_run features. (#67)
1 parent 32ee5d6 commit 91a03ab

File tree

10 files changed

+770
-55
lines changed

10 files changed

+770
-55
lines changed

.pre-commit-config.yaml

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ repos:
3131
rev: v2.4.1
3232
hooks:
3333
- id: codespell
34-
entry: codespell -q 3 -f --skip=".git,.github,README.md" --ignore-words-list="astroid"
34+
entry: codespell -q 3 -f --skip=".git,.github,README.md" --ignore-words-list="astroid,braket,te"
3535

3636
# Python code security
3737
- repo: https://github.com/PyCQA/bandit
@@ -57,21 +57,21 @@ repos:
5757
- id: nbstripout
5858
args: [--keep-id]
5959

60-
- repo: https://github.com/astral-sh/ruff-pre-commit
61-
rev: v0.11.7
62-
hooks:
63-
- id: ruff
64-
args: [--fix]
65-
- id: ruff-format
60+
# - repo: https://github.com/astral-sh/ruff-pre-commit
61+
# rev: v0.11.7
62+
# hooks:
63+
# - id: ruff
64+
# args: [--fix]
65+
# - id: ruff-format
6666

67-
- repo: https://github.com/pre-commit/mirrors-mypy
68-
rev: v1.15.0
69-
hooks:
70-
- id: mypy
71-
additional_dependencies:
72-
- "types-PyYAML"
73-
- "types-requests"
74-
- "types-setuptools"
67+
# - repo: https://github.com/pre-commit/mirrors-mypy
68+
# rev: v1.15.0
69+
# hooks:
70+
# - id: mypy
71+
# additional_dependencies:
72+
# - "types-PyYAML"
73+
# - "types-requests"
74+
# - "types-setuptools"
7575

7676
- repo: local
7777
hooks:

dreadnode/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
task_span = DEFAULT_INSTANCE.task_span
1919
push_update = DEFAULT_INSTANCE.push_update
2020
tag = DEFAULT_INSTANCE.tag
21+
get_run_context = DEFAULT_INSTANCE.get_run_context
22+
continue_run = DEFAULT_INSTANCE.continue_run
2123

2224
log_metric = DEFAULT_INSTANCE.log_metric
2325
log_param = DEFAULT_INSTANCE.log_param

dreadnode/data_types/audio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _process_audio_data(self) -> tuple[bytes, str, int | None, float | None]:
6363
Returns:
6464
A tuple of (audio_bytes, format_name, sample_rate, duration)
6565
"""
66-
if isinstance(self._data, (str, Path)) and Path(self._data).exists():
66+
if isinstance(self._data, str | Path) and Path(self._data).exists():
6767
return self._process_file_path()
6868
if isinstance(self._data, np.ndarray):
6969
return self._process_numpy_array()
@@ -159,7 +159,7 @@ def _generate_metadata(
159159
"x-python-datatype": "dreadnode.Audio.bytes",
160160
}
161161

162-
if isinstance(self._data, (str, Path)):
162+
if isinstance(self._data, str | Path):
163163
metadata["source-type"] = "file"
164164
metadata["source-path"] = str(self._data)
165165
elif isinstance(self._data, np.ndarray):

dreadnode/data_types/py.typed

Whitespace-only changes.

dreadnode/main.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from logfire._internal.exporters.remove_pending import RemovePendingSpansExporter
1717
from logfire._internal.stack_info import get_filepath_attribute, warn_at_user_stacklevel
1818
from logfire._internal.utils import safe_repr
19+
from opentelemetry import propagate
1920
from opentelemetry.exporter.otlp.proto.http import Compression
2021
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
2122
from opentelemetry.sdk.trace.export import BatchSpanProcessor
@@ -39,6 +40,7 @@
3940
FileSpanExporter,
4041
)
4142
from dreadnode.tracing.span import (
43+
RunContext,
4244
RunSpan,
4345
Span,
4446
TaskSpan,
@@ -154,7 +156,7 @@ def configure(
154156
server: The Dreadnode server URL.
155157
token: The Dreadnode API token.
156158
local_dir: The local directory to store data in.
157-
project: The defautlt project name to associate all runs with.
159+
project: The default project name to associate all runs with.
158160
service_name: The service name to use for OpenTelemetry.
159161
service_version: The service version to use for OpenTelemetry.
160162
console: Whether to log span information to the console.
@@ -198,7 +200,7 @@ def initialize(self) -> None:
198200
metric_readers: list[MetricReader] = []
199201

200202
self.server = self.server or (DEFAULT_SERVER_URL if self.token else None)
201-
if not (self.server and self.token and self.local_dir):
203+
if not (self.server or self.token or self.local_dir):
202204
warn_at_user_stacklevel(
203205
"Your current configuration won't persist run data anywhere. "
204206
"Use `dreadnode.init(server=..., token=...)`, `dreadnode.init(local_dir=...)`, "
@@ -280,6 +282,7 @@ def initialize(self) -> None:
280282
console=logfire.ConsoleOptions() if self.console is True else self.console,
281283
scrubbing=False,
282284
inspect_arguments=False,
285+
distributed_tracing=False,
283286
)
284287
self._logfire.config.ignore_no_config = True
285288

@@ -660,12 +663,16 @@ def run(
660663
run will be associated with a default project.
661664
autolog: Whether to automatically log task inputs, outputs, and execution metrics if unspecified.
662665
**attributes: Additional attributes to attach to the run span.
666+
667+
Returns:
668+
A RunSpan object that can be used as a context manager.
669+
The run will automatically be completed when the context manager exits.
663670
"""
664671
if not self._initialized:
665672
self.initialize()
666673

667674
if name is None:
668-
name = f"{coolname.generate_slug(2)}-{random.randint(100, 999)}" # noqa: S311
675+
name = f"{coolname.generate_slug(2)}-{random.randint(100, 999)}" # noqa: S311 # nosec
669676

670677
return RunSpan(
671678
name=name,
@@ -679,6 +686,52 @@ def run(
679686
autolog=autolog,
680687
)
681688

689+
def get_run_context(self) -> RunContext:
690+
"""
691+
Capture the current run context for transfer to another host, thread, or process.
692+
693+
Use `continue_run()` to continue the run anywhere else.
694+
695+
Returns:
696+
RunContext containing run state and trace propagation headers.
697+
698+
Raises:
699+
RuntimeError: If called outside of an active run.
700+
"""
701+
if (run := current_run_span.get()) is None:
702+
raise RuntimeError("get_run_context() must be called within a run")
703+
704+
# Capture OpenTelemetry trace context
705+
trace_context: dict[str, str] = {}
706+
propagate.inject(trace_context)
707+
708+
return {
709+
"run_id": run.run_id,
710+
"run_name": run.name,
711+
"project": run.project,
712+
"trace_context": trace_context,
713+
}
714+
715+
def continue_run(self, run_context: RunContext) -> RunSpan:
716+
"""
717+
Continue a run from captured context on a remote host.
718+
719+
Args:
720+
run_context: The RunContext captured from get_run_context().
721+
722+
Returns:
723+
A RunSpan object that can be used as a context manager.
724+
"""
725+
if not self._initialized:
726+
self.initialize()
727+
728+
return RunSpan.from_context(
729+
context=run_context,
730+
tracer=self._get_tracer(),
731+
file_system=self._fs,
732+
prefix_path=self._fs_prefix,
733+
)
734+
682735
def tag(self, *tag: str, to: ToObject = "task-or-run") -> None:
683736
"""
684737
Add one or many tags to the current task or run.

dreadnode/task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
236236
Returns:
237237
The span associated with task execution.
238238
"""
239-
run = current_run_span.get()
240-
if run is None or not run.is_recording:
239+
240+
if (run := current_run_span.get()) is None:
241241
raise RuntimeError("Tasks must be executed within a run")
242242

243243
log_inputs = run.autolog if isinstance(self.log_inputs, Inherited) else self.log_inputs

dreadnode/tracing/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
SPAN_NAMESPACE = "dreadnode"
44

5-
SpanType = t.Literal["run", "task", "span", "run_update"]
5+
SpanType = t.Literal["run", "task", "span", "run_update", "run_fragment"]
66

77
SPAN_ATTRIBUTE_VERSION = f"{SPAN_NAMESPACE}.version"
88
SPAN_ATTRIBUTE_TYPE = f"{SPAN_NAMESPACE}.type"

dreadnode/tracing/span.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from logfire._internal.tracer import OPEN_SPANS
2020
from logfire._internal.utils import uniquify_sequence
2121
from opentelemetry import context as context_api
22+
from opentelemetry import propagate
2223
from opentelemetry import trace as trace_api
2324
from opentelemetry.sdk.trace import ReadableSpan
2425
from opentelemetry.trace import Tracer
@@ -225,6 +226,15 @@ def log_event(
225226
)
226227

227228

229+
class RunContext(te.TypedDict):
230+
"""Context for transferring and continuing runs in other places."""
231+
232+
run_id: str
233+
run_name: str
234+
project: str
235+
trace_context: dict[str, str]
236+
237+
228238
class RunUpdateSpan(Span):
229239
def __init__(
230240
self,
@@ -274,10 +284,11 @@ def __init__(
274284
*,
275285
params: AnyDict | None = None,
276286
metrics: MetricDict | None = None,
277-
run_id: str | None = None,
278287
tags: t.Sequence[str] | None = None,
279288
autolog: bool = True,
280289
update_frequency: int = 5,
290+
run_id: str | ULID | None = None,
291+
type: SpanType = "run",
281292
) -> None:
282293
self.autolog = autolog
283294
self.project = project
@@ -307,6 +318,8 @@ def __init__(
307318
self._pending_object_schemas = deepcopy(self._object_schemas)
308319

309320
self._context_token: Token[RunSpan | None] | None = None # contextvars context
321+
self._remote_context: dict[str, str] | None = None # remote run trace context
322+
self._remote_token: object | None = None
310323
self._file_system = file_system
311324
self._prefix_path = prefix_path
312325

@@ -315,23 +328,55 @@ def __init__(
315328
SPAN_ATTRIBUTE_PROJECT: project,
316329
**attributes,
317330
}
318-
super().__init__(name, attributes, tracer, type="run", tags=tags)
331+
super().__init__(name, attributes, tracer, type=type, tags=tags)
332+
333+
@classmethod
334+
def from_context(
335+
cls,
336+
context: RunContext,
337+
tracer: Tracer,
338+
file_system: AbstractFileSystem,
339+
prefix_path: str,
340+
) -> "RunSpan":
341+
self = RunSpan(
342+
name=f"run.{context['run_id']}.fragment",
343+
project=context["project"],
344+
attributes={},
345+
tracer=tracer,
346+
file_system=file_system,
347+
prefix_path=prefix_path,
348+
type="run_fragment",
349+
run_id=context["run_id"],
350+
)
351+
352+
self._remote_context = context["trace_context"]
353+
354+
return self
319355

320356
def __enter__(self) -> te.Self:
321357
if current_run_span.get() is not None:
322358
raise RuntimeError("You cannot start a run span within another run")
323359

360+
if self._remote_context is not None:
361+
otel_context = propagate.extract(carrier=self._remote_context)
362+
self._remote_token = context_api.attach(otel_context)
363+
else:
364+
super().__enter__()
365+
324366
self._context_token = current_run_span.set(self)
325-
span = super().__enter__()
326367
self.push_update(force=True)
327-
return span
368+
369+
return self
328370

329371
def __exit__(
330372
self,
331373
exc_type: type[BaseException] | None,
332374
exc_value: BaseException | None,
333375
traceback: types.TracebackType | None,
334376
) -> None:
377+
if self._remote_context is not None:
378+
super().__enter__() # Now we can open our actually span
379+
335380
# When we finally close out the final span, include all the
336381
# full data attributes, so we can skip the update spans during
337382
# db queries later.
@@ -355,6 +400,10 @@ def __exit__(
355400
)
356401

357402
super().__exit__(exc_type, exc_value, traceback)
403+
404+
if self._remote_token is not None:
405+
context_api.detach(self._remote_token) # type: ignore [arg-type]
406+
358407
if self._context_token is not None:
359408
current_run_span.reset(self._context_token)
360409

@@ -416,7 +465,7 @@ def log_object(
416465

417466
# Create a composite key that represents both data and schema
418467
hash_input = f"{data_hash}:{schema_hash}"
419-
composite_hash = hashlib.sha1(hash_input.encode()).hexdigest()[:16] # noqa: S324
468+
composite_hash = hashlib.sha1(hash_input.encode()).hexdigest()[:16] # noqa: S324 # nosec
420469

421470
# Store schema if new
422471
if schema_hash not in self._object_schemas:

0 commit comments

Comments
 (0)