Skip to content

Commit 2aa453a

Browse files
authored
Temp Workflows Shouldn't Create Spans (#169)
1 parent 13bb1d6 commit 2aa453a

File tree

2 files changed

+96
-6
lines changed

2 files changed

+96
-6
lines changed

dbos/_context.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,18 +98,27 @@ def assign_workflow_id(self) -> str:
9898
wfid = str(uuid.uuid4())
9999
return wfid
100100

101-
def start_workflow(self, wfid: Optional[str], attributes: TracedAttributes) -> None:
101+
def start_workflow(
102+
self,
103+
wfid: Optional[str],
104+
attributes: TracedAttributes,
105+
is_temp_workflow: bool = False,
106+
) -> None:
102107
if wfid is None or len(wfid) == 0:
103108
wfid = self.assign_workflow_id()
104109
self.id_assigned_for_next_workflow = ""
105110
self.workflow_id = wfid
106111
self.function_id = 0
107-
self._start_span(attributes)
112+
if not is_temp_workflow:
113+
self._start_span(attributes)
108114

109-
def end_workflow(self, exc_value: Optional[BaseException]) -> None:
115+
def end_workflow(
116+
self, exc_value: Optional[BaseException], is_temp_workflow: bool = False
117+
) -> None:
110118
self.workflow_id = ""
111119
self.function_id = -1
112-
self._end_span(exc_value)
120+
if not is_temp_workflow:
121+
self._end_span(exc_value)
113122

114123
def is_within_workflow(self) -> bool:
115124
return len(self.workflow_id) > 0
@@ -349,6 +358,7 @@ class EnterDBOSWorkflow(AbstractContextManager[DBOSContext, Literal[False]]):
349358
def __init__(self, attributes: TracedAttributes) -> None:
350359
self.created_ctx = False
351360
self.attributes = attributes
361+
self.is_temp_workflow = attributes["name"] == "temp_wf"
352362

353363
def __enter__(self) -> DBOSContext:
354364
# Code to create a basic context
@@ -359,7 +369,7 @@ def __enter__(self) -> DBOSContext:
359369
_set_local_dbos_context(ctx)
360370
assert not ctx.is_within_workflow()
361371
ctx.start_workflow(
362-
None, self.attributes
372+
None, self.attributes, self.is_temp_workflow
363373
) # Will get from the context's next workflow ID
364374
return ctx
365375

@@ -371,7 +381,7 @@ def __exit__(
371381
) -> Literal[False]:
372382
ctx = assert_current_dbos_context()
373383
assert ctx.is_within_workflow()
374-
ctx.end_workflow(exc_value)
384+
ctx.end_workflow(exc_value, self.is_temp_workflow)
375385
# Code to clean up the basic context if we created it
376386
if self.created_ctx:
377387
_clear_local_dbos_context()

tests/test_spans.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from typing import Tuple
2+
3+
from fastapi import FastAPI
4+
from fastapi.testclient import TestClient
5+
from opentelemetry.sdk import trace as tracesdk
6+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
7+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
8+
9+
from dbos import DBOS
10+
from dbos._tracer import dbos_tracer
11+
12+
13+
def test_spans(dbos: DBOS) -> None:
14+
15+
@DBOS.workflow()
16+
def test_workflow() -> None:
17+
test_step()
18+
19+
@DBOS.step()
20+
def test_step() -> None:
21+
return
22+
23+
exporter = InMemorySpanExporter()
24+
span_processor = SimpleSpanProcessor(exporter)
25+
provider = tracesdk.TracerProvider()
26+
provider.add_span_processor(span_processor)
27+
dbos_tracer.set_provider(provider)
28+
29+
test_workflow()
30+
test_step()
31+
32+
spans = exporter.get_finished_spans()
33+
34+
assert len(spans) == 3
35+
36+
for span in spans:
37+
assert span.attributes is not None
38+
assert span.context is not None
39+
40+
assert spans[0].name == test_step.__name__
41+
assert spans[1].name == test_workflow.__name__
42+
assert spans[2].name == test_step.__name__
43+
44+
assert spans[0].parent.span_id == spans[1].context.span_id # type: ignore
45+
assert spans[1].parent == None
46+
assert spans[2].parent == None
47+
48+
49+
def test_temp_wf_fastapi(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None:
50+
_, app = dbos_fastapi
51+
52+
@app.get("/step")
53+
@DBOS.step()
54+
def test_step_endpoint() -> str:
55+
return "test"
56+
57+
exporter = InMemorySpanExporter()
58+
span_processor = SimpleSpanProcessor(exporter)
59+
provider = tracesdk.TracerProvider()
60+
provider.add_span_processor(span_processor)
61+
dbos_tracer.set_provider(provider)
62+
63+
client = TestClient(app)
64+
response = client.get("/step")
65+
assert response.status_code == 200
66+
assert response.text == '"test"'
67+
68+
spans = exporter.get_finished_spans()
69+
70+
assert len(spans) == 2
71+
72+
for span in spans:
73+
assert span.attributes is not None
74+
assert span.context is not None
75+
76+
assert spans[0].name == test_step_endpoint.__name__
77+
assert spans[1].name == "/step"
78+
79+
assert spans[0].parent.span_id == spans[1].context.span_id # type:ignore
80+
assert spans[1].parent == None

0 commit comments

Comments
 (0)