47 changes: 47 additions & 0 deletions dbos/_dbos.py
@@ -341,6 +341,7 @@ def __init__(
self.conductor_key: Optional[str] = conductor_key
if config.get("conductor_key"):
self.conductor_key = config.get("conductor_key")
self.enable_patching = config.get("enable_patching") is True
self.conductor_websocket: Optional[ConductorWebsocket] = None
self._background_event_loop: BackgroundEventLoop = BackgroundEventLoop()
self._active_workflows_set: set[str] = set()
@@ -350,6 +351,8 @@ def __init__(
# Globally set the application version and executor ID.
# In DBOS Cloud, instead use the values supplied through environment variables.
if not os.environ.get("DBOS__CLOUD") == "true":
if self.enable_patching:
Member:
I think we don't need this config, because users should be able to use patching and versioning together. This code forces the version to be a weird PATCHING_ENABLED string, which is not intuitive.

kraftp (Member, Author) replied on Nov 24, 2025:
This makes the behavior much more intuitive. Patching requires a static version string, and this check provides a clean error if you don't have one. You can still override it if you really want to. Otherwise you'd think you're using patching, but your workflows wouldn't recover and you wouldn't know why.

GlobalParams.app_version = "PATCHING_ENABLED"
if (
"application_version" in config
and config["application_version"] is not None
@@ -1525,6 +1528,50 @@ async def read_stream_async(
await asyncio.sleep(1.0)
continue

@classmethod
def patch(cls, patch_name: str) -> bool:
if not _get_dbos_instance().enable_patching:
raise DBOSException("enable_patching must be True in DBOS configuration")
ctx = get_local_dbos_context()
if ctx is None or not ctx.is_workflow():
raise DBOSException("DBOS.patch must be called from a workflow")
workflow_id = ctx.workflow_id
function_id = ctx.function_id
patch_name = f"DBOS.patch-{patch_name}"
patched = _get_dbos_instance()._sys_db.patch(
workflow_id=workflow_id, function_id=function_id + 1, patch_name=patch_name
)
# If the patch was applied, increment function ID
if patched:
ctx.function_id += 1
return patched

@classmethod
def patch_async(cls, patch_name: str) -> Coroutine[Any, Any, bool]:
return asyncio.to_thread(cls.patch, patch_name)

@classmethod
def deprecate_patch(cls, patch_name: str) -> bool:
Member:
The problem with this approach is that it's confusing where to run deprecate_patch: does it replace the original if-else statements? What if I have consecutive if-else branches?

Member:
I think a better design is to explicitly register patches and explicitly enable or disable them. It's easier for testing too: you can test the patched code while controlling which patch is in effect.

if not _get_dbos_instance().enable_patching:
raise DBOSException("enable_patching must be True in DBOS configuration")
ctx = get_local_dbos_context()
if ctx is None or not ctx.is_workflow():
raise DBOSException("DBOS.deprecate_patch must be called from a workflow")
workflow_id = ctx.workflow_id
function_id = ctx.function_id
patch_name = f"DBOS.patch-{patch_name}"
patch_exists = _get_dbos_instance()._sys_db.deprecate_patch(
workflow_id=workflow_id, function_id=function_id + 1, patch_name=patch_name
)
# If the patch is already in history, increment function ID
if patch_exists:
ctx.function_id += 1
return True

@classmethod
def deprecate_patch_async(cls, patch_name: str) -> Coroutine[Any, Any, bool]:
return asyncio.to_thread(cls.deprecate_patch, patch_name)

@classproperty
def tracer(self) -> DBOSTracer:
"""Return the DBOS OpenTelemetry tracer."""
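To make the new API concrete, here is a minimal usage sketch (not part of this PR; the workflow and step names are hypothetical). A workflow guards changed logic behind `DBOS.patch`: new executions checkpoint a patch marker and take the new branch, while recovering executions that predate the patch find their original checkpoint at that position and replay the old branch.

```python
from dbos import DBOS

@DBOS.step()
def old_step() -> str:
    return "old behavior"

@DBOS.step()
def new_step() -> str:
    return "new behavior"

@DBOS.workflow()
def checkout_workflow() -> str:
    # New executions record a "DBOS.patch-use-new-step" marker and get True;
    # executions recovered from before the patch see their original
    # checkpoint here, get False, and deterministically replay the old branch.
    if DBOS.patch("use-new-step"):
        return new_step()
    return old_step()
```

Once no pre-patch executions remain, the old branch can be retired by swapping `DBOS.patch` for `DBOS.deprecate_patch`, which honors markers already in history without writing new ones; async workflows would call `await DBOS.patch_async(...)` instead.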
1 change: 1 addition & 0 deletions dbos/_dbos_config.py
@@ -63,6 +63,7 @@ class DBOSConfig(TypedDict, total=False):
conductor_key: Optional[str]
conductor_url: Optional[str]
serializer: Optional[Serializer]
enable_patching: Optional[bool]


class RuntimeConfig(TypedDict, total=False):
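For reference, a minimal configuration sketch enabling the new flag (the `name` value is illustrative; database settings are omitted). Per the `__init__` change above, enabling patching pins `app_version` to the `PATCHING_ENABLED` sentinel unless an explicit `application_version` is supplied.

```python
from dbos import DBOS, DBOSConfig

config: DBOSConfig = {
    "name": "example-app",    # illustrative application name
    "enable_patching": True,  # opt in to DBOS.patch / DBOS.deprecate_patch
}
DBOS(config=config)
```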
16 changes: 12 additions & 4 deletions dbos/_error.py
@@ -143,7 +143,7 @@ def __init__(self, msg: str):
self.status_code = 403

def __reduce__(self) -> Any:
- # Tell jsonpickle how to reconstruct this object
+ # Tell pickle how to reconstruct this object
return (self.__class__, (self.msg,))


@@ -162,7 +162,7 @@ def __init__(
)

def __reduce__(self) -> Any:
- # Tell jsonpickle how to reconstruct this object
+ # Tell pickle how to reconstruct this object
return (self.__class__, (self.step_name, self.max_retries, self.errors))


@@ -182,11 +182,19 @@ class DBOSUnexpectedStepError(DBOSException):
def __init__(
self, workflow_id: str, step_id: int, expected_name: str, recorded_name: str
) -> None:
self.inputs = (workflow_id, step_id, expected_name, recorded_name)
super().__init__(
f"During execution of workflow {workflow_id} step {step_id}, function {recorded_name} was recorded when {expected_name} was expected. Check that your workflow is deterministic.",
dbos_error_code=DBOSErrorCode.UnexpectedStep.value,
)

def __reduce__(self) -> Any:
# Tell pickle how to reconstruct this object
return (
self.__class__,
self.inputs,
)


class DBOSQueueDeduplicatedError(DBOSException):
"""Exception raised when a workflow is deduplicated in the queue."""
@@ -203,7 +211,7 @@ def __init__(
)

def __reduce__(self) -> Any:
- # Tell jsonpickle how to reconstruct this object
+ # Tell pickle how to reconstruct this object
return (
self.__class__,
(self.workflow_id, self.queue_name, self.deduplication_id),
@@ -219,7 +227,7 @@ def __init__(self, workflow_id: str):
)

def __reduce__(self) -> Any:
- # Tell jsonpickle how to reconstruct this object
+ # Tell pickle how to reconstruct this object
return (self.__class__, (self.workflow_id,))


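The `__reduce__` additions matter because these exceptions have custom `__init__` signatures: default exception pickling reconstructs via `cls(*args)`, which fails when the constructor takes more than a message. Storing the constructor arguments (here in `self.inputs`) lets pickle replay them. A quick round-trip illustration (the class is from this diff; the argument values are made up):

```python
import pickle

from dbos._error import DBOSUnexpectedStepError

err = DBOSUnexpectedStepError("wf-123", 2, "expected_step", "recorded_step")
restored = pickle.loads(pickle.dumps(err))
# __reduce__ replays the saved constructor arguments, so the restored
# exception carries the same formatted message as the original.
assert str(restored) == str(err)
```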
40 changes: 40 additions & 0 deletions dbos/_sys_db.py
@@ -2338,3 +2338,43 @@ def get_metrics(self, start_time: str, end_time: str) -> List[MetricData]:
)

return metrics

@db_retry()
def patch(self, *, workflow_id: str, function_id: int, patch_name: str) -> bool:
"""If there is no checkpoint for this point in history,
insert a patch marker and return True.
Otherwise, return whether the checkpoint is this patch marker."""
with self.engine.begin() as c:
checkpoint_name: str | None = c.execute(
sa.select(SystemSchema.operation_outputs.c.function_name).where(
(SystemSchema.operation_outputs.c.workflow_uuid == workflow_id)
& (SystemSchema.operation_outputs.c.function_id == function_id)
)
).scalar()
if checkpoint_name is None:
result: OperationResultInternal = {
"workflow_uuid": workflow_id,
"function_id": function_id,
"function_name": patch_name,
"output": None,
"error": None,
"started_at_epoch_ms": int(time.time() * 1000),
}
self._record_operation_result_txn(result, c)
return True
else:
return checkpoint_name == patch_name

@db_retry()
def deprecate_patch(
self, *, workflow_id: str, function_id: int, patch_name: str
) -> bool:
"""Respect patch markers in history, but do not introduce new patch markers"""
with self.engine.begin() as c:
checkpoint_name: str | None = c.execute(
sa.select(SystemSchema.operation_outputs.c.function_name).where(
(SystemSchema.operation_outputs.c.workflow_uuid == workflow_id)
& (SystemSchema.operation_outputs.c.function_id == function_id)
)
).scalar()
return checkpoint_name == patch_name
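To illustrate the checkpoint semantics these two queries implement, here is a self-contained, in-memory sketch (the real implementation reads and writes the `operation_outputs` table shown above; this model only mirrors the logic):

```python
# Hypothetical stand-in for operation_outputs, keyed by
# (workflow_id, function_id) -> recorded function_name.
checkpoints: dict[tuple[str, int], str] = {}

def patch(workflow_id: str, function_id: int, patch_name: str) -> bool:
    recorded = checkpoints.get((workflow_id, function_id))
    if recorded is None:
        # No checkpoint at this position in history: insert a patch marker.
        checkpoints[(workflow_id, function_id)] = patch_name
        return True
    # A checkpoint exists: True only if it is this patch's own marker.
    return recorded == patch_name

def deprecate_patch(workflow_id: str, function_id: int, patch_name: str) -> bool:
    # Never inserts a marker; only reports whether one is already recorded.
    return checkpoints.get((workflow_id, function_id)) == patch_name

assert patch("wf-new", 1, "DBOS.patch-fix")      # fresh run: marker inserted
assert patch("wf-new", 1, "DBOS.patch-fix")      # recovery: replays own marker
checkpoints[("wf-old", 1)] = "charge_customer"   # pre-patch run recorded a step
assert not patch("wf-old", 1, "DBOS.patch-fix")  # so it takes the old branch
```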
2 changes: 1 addition & 1 deletion tests/test_dbos.py
@@ -2078,7 +2078,7 @@ class JsonSerializer(Serializer):
def serialize(self, data: Any) -> str:
return json.dumps(data)

- def deserialize(cls, serialized_data: str) -> Any:
+ def deserialize(self, serialized_data: str) -> Any:
return json.loads(serialized_data)

# Configure DBOS with a JSON-based custom serializer
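A side note on the test fix: the original spelling still passed at runtime because Python binds the instance to the first positional parameter of a method regardless of its name, so calling it `cls` was misleading rather than broken. A small standalone illustration (hypothetical class, mirroring the test's serializer):

```python
import json
from typing import Any

class JsonSerializer:
    def serialize(self, data: Any) -> str:
        return json.dumps(data)

    # Misnamed first parameter: it still receives the instance when
    # invoked as a method, which is why the old code worked anyway.
    def deserialize(cls, serialized_data: str) -> Any:
        return json.loads(serialized_data)

s = JsonSerializer()
assert s.deserialize("[1, 2]") == [1, 2]  # `cls` is bound to `s` here
```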