55
66import pytest
77import sqlalchemy as sa
8+ from opentelemetry ._logs import set_logger_provider
9+ from opentelemetry .sdk ._logs import LoggerProvider , LoggingHandler
10+ from opentelemetry .sdk ._logs .export import BatchLogRecordProcessor , InMemoryLogExporter
811
912# Public API
1013from dbos import (
1922from dbos ._dbos import WorkflowHandle
2023from dbos ._dbos_config import ConfigFile
2124from dbos ._error import DBOSAwaitedWorkflowCancelledError , DBOSException
25+ from dbos ._logger import dbos_logger
26+ from dbos ._registrations import get_dbos_func_name
2227
2328
2429@pytest .mark .asyncio
@@ -31,7 +36,7 @@ async def test_async_workflow(dbos: DBOS) -> None:
3136 async def test_workflow (var1 : str , var2 : str ) -> str :
3237 nonlocal wf_counter
3338 wf_counter += 1
34- res1 = test_transaction ( var1 )
39+ res1 = await asyncio . to_thread ( test_transaction , var1 )
3540 res2 = await test_step (var2 )
3641 DBOS .logger .info ("I'm test_workflow" )
3742 return res1 + res2
@@ -88,7 +93,7 @@ async def test_async_step(dbos: DBOS) -> None:
8893 async def test_workflow (var1 : str , var2 : str ) -> str :
8994 nonlocal wf_counter
9095 wf_counter += 1
91- res1 = test_transaction ( var1 )
96+ res1 = await asyncio . to_thread ( test_transaction , var1 )
9297 res2 = await test_step (var2 )
9398 DBOS .logger .info ("I'm test_workflow" )
9499 return res1 + res2
@@ -325,6 +330,7 @@ def test_async_tx_raises(config: ConfigFile) -> None:
325330 async def test_async_tx () -> None :
326331 pass
327332
333+ assert "is a coroutine function" in str (exc_info .value )
328334 # destroy call needed to avoid "functions were registered but DBOS() was not called" warning
329335 DBOS .destroy (destroy_registry = True )
330336
@@ -343,12 +349,12 @@ async def test_workflow(var1: str, var2: str) -> str:
343349 wf_el_id = id (asyncio .get_running_loop ())
344350 nonlocal wf_counter
345351 wf_counter += 1
346- res2 = test_step (var2 )
352+ res2 = await test_step (var2 )
347353 DBOS .logger .info ("I'm test_workflow" )
348354 return var1 + res2
349355
350356 @DBOS .step ()
351- def test_step (var : str ) -> str :
357+ async def test_step (var : str ) -> str :
352358 nonlocal step_el_id
353359 step_el_id = id (asyncio .get_running_loop ())
354360 nonlocal step_counter
@@ -605,3 +611,83 @@ async def run_workflow_task() -> str:
605611 # Verify the workflow completes despite the task cancellation
606612 handle : WorkflowHandleAsync [str ] = await DBOS .retrieve_workflow_async (wfid )
607613 assert await handle .get_result () == "completed"
614+
615+
616+ @pytest .mark .asyncio
617+ async def test_check_async_violation (dbos : DBOS ) -> None :
618+ # Set up in-memory log exporter
619+ log_exporter = InMemoryLogExporter () # type: ignore
620+ log_processor = BatchLogRecordProcessor (log_exporter )
621+ log_provider = LoggerProvider ()
622+ log_provider .add_log_record_processor (log_processor )
623+ set_logger_provider (log_provider )
624+ dbos_logger .addHandler (LoggingHandler (logger_provider = log_provider ))
625+
626+ @DBOS .workflow ()
627+ def sync_workflow () -> str :
628+ return "sync"
629+
630+ @DBOS .step ()
631+ def sync_step () -> str :
632+ return "step"
633+
634+ @DBOS .workflow ()
635+ async def async_workflow_sync_step () -> str :
636+ return sync_step ()
637+
638+ @DBOS .transaction ()
639+ def sync_transaction () -> str :
640+ return "txn"
641+
642+ @DBOS .workflow ()
643+ async def async_workflow_sync_txn () -> str :
644+ return sync_transaction ()
645+
646+ # Call a sync workflow should log a warning
647+ sync_workflow ()
648+
649+ log_processor .force_flush (timeout_millis = 5000 )
650+ logs = log_exporter .get_finished_logs ()
651+ assert len (logs ) == 1
652+ assert (
653+ logs [0 ].log_record .body is not None
654+ and f"Sync workflow ({ get_dbos_func_name (sync_workflow )} ) shouldn't be invoked from within another async function."
655+ in logs [0 ].log_record .body
656+ )
657+ log_exporter .clear ()
658+
659+ # Call a sync step from within an async workflow should log a warning
660+ await async_workflow_sync_step ()
661+ log_processor .force_flush (timeout_millis = 5000 )
662+ logs = log_exporter .get_finished_logs ()
663+ assert len (logs ) == 1
664+ assert (
665+ logs [0 ].log_record .body is not None
666+ and f"Sync step ({ get_dbos_func_name (sync_step )} ) shouldn't be invoked from within another async function."
667+ in logs [0 ].log_record .body
668+ )
669+ log_exporter .clear ()
670+
671+ # Directly call a sync step should log a warning
672+ sync_step ()
673+ log_processor .force_flush (timeout_millis = 5000 )
674+ logs = log_exporter .get_finished_logs ()
675+ assert len (logs ) == 1
676+ assert (
677+ logs [0 ].log_record .body is not None
678+ and f"Sync step ({ get_dbos_func_name (sync_step )} ) shouldn't be invoked from within another async function."
679+ in logs [0 ].log_record .body
680+ )
681+ log_exporter .clear ()
682+
683+ # Call a sync transaction from within an async workflow should log a warning
684+ await async_workflow_sync_txn ()
685+ log_processor .force_flush (timeout_millis = 5000 )
686+ logs = log_exporter .get_finished_logs ()
687+ assert len (logs ) == 1
688+ assert (
689+ logs [0 ].log_record .body is not None
690+ and f"Transaction function ({ get_dbos_func_name (sync_transaction )} ) shouldn't be invoked from within another async function."
691+ in logs [0 ].log_record .body
692+ )
693+ log_exporter .clear ()
0 commit comments