Skip to content

Commit 16657f9

Browse files
authored
Don't set workflow ID if the header is not set (#188)
This PR fixes a minor issue where the FastAPI/Flask handler would falsely print `Multiple workflows started in the same SetWorkflowID block.` The solution is to not enter the `with SetWorkflowID` block if the HTTP request doesn't contain the header to set the WF ID. Added tests to make sure no warning logs for those tests.
1 parent 983a028 commit 16657f9

File tree

4 files changed

+52
-6
lines changed

4 files changed

+52
-6
lines changed

dbos/_fastapi.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,11 @@ async def dbos_fastapi_middleware(
9494
with EnterDBOSHandler(attributes):
9595
ctx = assert_current_dbos_context()
9696
ctx.request = _make_request(request)
97-
workflow_id = request.headers.get("dbos-idempotency-key", "")
98-
with SetWorkflowID(workflow_id):
97+
workflow_id = request.headers.get("dbos-idempotency-key")
98+
if workflow_id is not None:
99+
# Set the workflow ID for the handler
100+
with SetWorkflowID(workflow_id):
101+
response = await call_next(request)
102+
else:
99103
response = await call_next(request)
100104
return response

dbos/_flask.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,12 @@ def __call__(self, environ: Any, start_response: Any) -> Any:
3434
with EnterDBOSHandler(attributes):
3535
ctx = assert_current_dbos_context()
3636
ctx.request = _make_request(request)
37-
workflow_id = request.headers.get("dbos-idempotency-key", "")
38-
with SetWorkflowID(workflow_id):
37+
workflow_id = request.headers.get("dbos-idempotency-key")
38+
if workflow_id is not None:
39+
# Set the workflow ID for the handler
40+
with SetWorkflowID(workflow_id):
41+
response = self.app(environ, start_response)
42+
else:
3943
response = self.app(environ, start_response)
4044
return response
4145

tests/test_fastapi.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import logging
12
import uuid
23
from typing import Tuple
34

5+
import pytest
46
import sqlalchemy as sa
57
from fastapi import FastAPI
68
from fastapi.testclient import TestClient
@@ -12,7 +14,9 @@
1214
from dbos._context import assert_current_dbos_context
1315

1416

15-
def test_simple_endpoint(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None:
17+
def test_simple_endpoint(
18+
dbos_fastapi: Tuple[DBOS, FastAPI], caplog: pytest.LogCaptureFixture
19+
) -> None:
1620
dbos, app = dbos_fastapi
1721
client = TestClient(app)
1822

@@ -32,6 +36,7 @@ def test_workflow(var1: str, var2: str) -> str:
3236
res2 = test_step(var2)
3337
return res1 + res2
3438

39+
@app.get("/transaction/{var}")
3540
@DBOS.transaction()
3641
def test_transaction(var: str) -> str:
3742
rows = DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall()
@@ -41,13 +46,27 @@ def test_transaction(var: str) -> str:
4146
def test_step(var: str) -> str:
4247
return var
4348

49+
original_propagate = logging.getLogger("dbos").propagate
50+
caplog.set_level(logging.WARNING, "dbos")
51+
logging.getLogger("dbos").propagate = True
52+
4453
response = client.get("/workflow/bob/bob")
4554
assert response.status_code == 200
4655
assert response.text == '"bob1bob"'
56+
assert caplog.text == ""
4757

4858
response = client.get("/endpoint/bob/bob")
4959
assert response.status_code == 200
5060
assert response.text == '"bob1bob"'
61+
assert caplog.text == ""
62+
63+
response = client.get("/transaction/bob")
64+
assert response.status_code == 200
65+
assert response.text == '"bob1"'
66+
assert caplog.text == ""
67+
68+
# Reset logging
69+
logging.getLogger("dbos").propagate = original_propagate
5170

5271

5372
def test_start_workflow(dbos_fastapi: Tuple[DBOS, FastAPI]) -> None:

tests/test_flask.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1+
import logging
12
import uuid
23
from typing import Tuple
34

5+
import pytest
46
import sqlalchemy as sa
57
from flask import Flask, Response, jsonify
68

79
from dbos import DBOS
810
from dbos._context import assert_current_dbos_context
911

1012

11-
def test_flask_endpoint(dbos_flask: Tuple[DBOS, Flask]) -> None:
13+
def test_flask_endpoint(
14+
dbos_flask: Tuple[DBOS, Flask], caplog: pytest.LogCaptureFixture
15+
) -> None:
1216
_, app = dbos_flask
1317

1418
@app.route("/endpoint/<var1>/<var2>")
@@ -27,6 +31,7 @@ def test_workflow(var1: str, var2: str) -> Response:
2731
result = res1 + res2
2832
return jsonify({"result": result})
2933

34+
@app.route("/transaction/<var>")
3035
@DBOS.transaction()
3136
def test_transaction(var: str) -> str:
3237
rows = DBOS.sql_session.execute(sa.text("SELECT 1")).fetchall()
@@ -39,13 +44,27 @@ def test_step(var: str) -> str:
3944
app.config["TESTING"] = True
4045
client = app.test_client()
4146

47+
original_propagate = logging.getLogger("dbos").propagate
48+
caplog.set_level(logging.WARNING, "dbos")
49+
logging.getLogger("dbos").propagate = True
50+
4251
response = client.get("/endpoint/a/b")
4352
assert response.status_code == 200
4453
assert response.json == {"result": "a1b"}
54+
assert caplog.text == ""
4555

4656
response = client.get("/workflow/a/b")
4757
assert response.status_code == 200
4858
assert response.json == {"result": "a1b"}
59+
assert caplog.text == ""
60+
61+
response = client.get("/transaction/bob")
62+
assert response.status_code == 200
63+
assert response.text == "bob1"
64+
assert caplog.text == ""
65+
66+
# Reset logging
67+
logging.getLogger("dbos").propagate = original_propagate
4968

5069

5170
def test_endpoint_recovery(dbos_flask: Tuple[DBOS, Flask]) -> None:

0 commit comments

Comments
 (0)