diff --git a/dbos/_dbos.py b/dbos/_dbos.py
index d0dc4394..41dd06db 100644
--- a/dbos/_dbos.py
+++ b/dbos/_dbos.py
@@ -341,6 +341,7 @@ def __init__(
         self.conductor_key: Optional[str] = conductor_key
         if config.get("conductor_key"):
             self.conductor_key = config.get("conductor_key")
+        self.enable_patching = config.get("enable_patching") == True
         self.conductor_websocket: Optional[ConductorWebsocket] = None
         self._background_event_loop: BackgroundEventLoop = BackgroundEventLoop()
         self._active_workflows_set: set[str] = set()
@@ -350,6 +351,8 @@ def __init__(
         # Globally set the application version and executor ID.
         # In DBOS Cloud, instead use the values supplied through environment variables.
         if not os.environ.get("DBOS__CLOUD") == "true":
+            if self.enable_patching:
+                GlobalParams.app_version = "PATCHING_ENABLED"
             if (
                 "application_version" in config
                 and config["application_version"] is not None
@@ -1525,6 +1528,50 @@ async def read_stream_async(
                 await asyncio.sleep(1.0)
                 continue
 
+    @classmethod
+    def patch(cls, patch_name: str) -> bool:
+        if not _get_dbos_instance().enable_patching:
+            raise DBOSException("enable_patching must be True in DBOS configuration")
+        ctx = get_local_dbos_context()
+        if ctx is None or not ctx.is_workflow():
+            raise DBOSException("DBOS.patch must be called from a workflow")
+        workflow_id = ctx.workflow_id
+        function_id = ctx.function_id
+        patch_name = f"DBOS.patch-{patch_name}"
+        patched = _get_dbos_instance()._sys_db.patch(
+            workflow_id=workflow_id, function_id=function_id + 1, patch_name=patch_name
+        )
+        # If the patch was applied, increment function ID
+        if patched:
+            ctx.function_id += 1
+        return patched
+
+    @classmethod
+    def patch_async(cls, patch_name: str) -> Coroutine[Any, Any, bool]:
+        return asyncio.to_thread(cls.patch, patch_name)
+
+    @classmethod
+    def deprecate_patch(cls, patch_name: str) -> bool:
+        if not _get_dbos_instance().enable_patching:
+            raise DBOSException("enable_patching must be True in DBOS configuration")
+        ctx = get_local_dbos_context()
+        if ctx is None or not ctx.is_workflow():
+            raise DBOSException("DBOS.deprecate_patch must be called from a workflow")
+        workflow_id = ctx.workflow_id
+        function_id = ctx.function_id
+        patch_name = f"DBOS.patch-{patch_name}"
+        patch_exists = _get_dbos_instance()._sys_db.deprecate_patch(
+            workflow_id=workflow_id, function_id=function_id + 1, patch_name=patch_name
+        )
+        # If the patch is already in history, increment function ID
+        if patch_exists:
+            ctx.function_id += 1
+        return True
+
+    @classmethod
+    def deprecate_patch_async(cls, patch_name: str) -> Coroutine[Any, Any, bool]:
+        return asyncio.to_thread(cls.deprecate_patch, patch_name)
+
     @classproperty
     def tracer(self) -> DBOSTracer:
         """Return the DBOS OpenTelemetry tracer."""
diff --git a/dbos/_dbos_config.py b/dbos/_dbos_config.py
index 85ff7552..59fa0fc4 100644
--- a/dbos/_dbos_config.py
+++ b/dbos/_dbos_config.py
@@ -63,6 +63,7 @@ class DBOSConfig(TypedDict, total=False):
     conductor_key: Optional[str]
     conductor_url: Optional[str]
     serializer: Optional[Serializer]
+    enable_patching: Optional[bool]
 
 
 class RuntimeConfig(TypedDict, total=False):
diff --git a/dbos/_error.py b/dbos/_error.py
index 6c11bc04..3f9558b4 100644
--- a/dbos/_error.py
+++ b/dbos/_error.py
@@ -143,7 +143,7 @@ def __init__(self, msg: str):
         self.status_code = 403
 
     def __reduce__(self) -> Any:
-        # Tell jsonpickle how to reconstruct this object
+        # Tell pickle how to reconstruct this object
         return (self.__class__, (self.msg,))
 
 
@@ -162,7 +162,7 @@ def __init__(
         )
 
     def __reduce__(self) -> Any:
-        # Tell jsonpickle how to reconstruct this object
+        # Tell pickle how to reconstruct this object
         return (self.__class__, (self.step_name, self.max_retries, self.errors))
 
 
@@ -182,11 +182,19 @@ class DBOSUnexpectedStepError(DBOSException):
     def __init__(
         self, workflow_id: str, step_id: int, expected_name: str, recorded_name: str
     ) -> None:
+        self.inputs = (workflow_id, step_id, expected_name, recorded_name)
         super().__init__(
             f"During execution of workflow {workflow_id} step {step_id}, function {recorded_name} was recorded when {expected_name} was expected. Check that your workflow is deterministic.",
             dbos_error_code=DBOSErrorCode.UnexpectedStep.value,
         )
 
+    def __reduce__(self) -> Any:
+        # Tell pickle how to reconstruct this object
+        return (
+            self.__class__,
+            self.inputs,
+        )
+
 
 class DBOSQueueDeduplicatedError(DBOSException):
     """Exception raised when a workflow is deduplicated in the queue."""
@@ -203,7 +211,7 @@ def __init__(
         )
 
     def __reduce__(self) -> Any:
-        # Tell jsonpickle how to reconstruct this object
+        # Tell pickle how to reconstruct this object
         return (
             self.__class__,
             (self.workflow_id, self.queue_name, self.deduplication_id),
@@ -219,7 +227,7 @@ def __init__(self, workflow_id: str):
         )
 
     def __reduce__(self) -> Any:
-        # Tell jsonpickle how to reconstruct this object
+        # Tell pickle how to reconstruct this object
         return (self.__class__, (self.workflow_id,))
 
 
diff --git a/dbos/_sys_db.py b/dbos/_sys_db.py
index 860fa7f7..6b9f22cd 100644
--- a/dbos/_sys_db.py
+++ b/dbos/_sys_db.py
@@ -2338,3 +2338,43 @@ def get_metrics(self, start_time: str, end_time: str) -> List[MetricData]:
             )
 
         return metrics
+
+    @db_retry()
+    def patch(self, *, workflow_id: str, function_id: int, patch_name: str) -> bool:
+        """If there is no checkpoint for this point in history,
+        insert a patch marker and return True.
+        Otherwise, return whether the checkpoint is this patch marker."""
+        with self.engine.begin() as c:
+            checkpoint_name: str | None = c.execute(
+                sa.select(SystemSchema.operation_outputs.c.function_name).where(
+                    (SystemSchema.operation_outputs.c.workflow_uuid == workflow_id)
+                    & (SystemSchema.operation_outputs.c.function_id == function_id)
+                )
+            ).scalar()
+            if checkpoint_name is None:
+                result: OperationResultInternal = {
+                    "workflow_uuid": workflow_id,
+                    "function_id": function_id,
+                    "function_name": patch_name,
+                    "output": None,
+                    "error": None,
+                    "started_at_epoch_ms": int(time.time() * 1000),
+                }
+                self._record_operation_result_txn(result, c)
+                return True
+            else:
+                return checkpoint_name == patch_name
+
+    @db_retry()
+    def deprecate_patch(
+        self, *, workflow_id: str, function_id: int, patch_name: str
+    ) -> bool:
+        """Respect patch markers in history, but do not introduce new patch markers"""
+        with self.engine.begin() as c:
+            checkpoint_name: str | None = c.execute(
+                sa.select(SystemSchema.operation_outputs.c.function_name).where(
+                    (SystemSchema.operation_outputs.c.workflow_uuid == workflow_id)
+                    & (SystemSchema.operation_outputs.c.function_id == function_id)
+                )
+            ).scalar()
+            return checkpoint_name == patch_name
diff --git a/tests/test_dbos.py b/tests/test_dbos.py
index 9ccb7b5f..335726c0 100644
--- a/tests/test_dbos.py
+++ b/tests/test_dbos.py
@@ -2078,7 +2078,7 @@ class JsonSerializer(Serializer):
         def serialize(self, data: Any) -> str:
             return json.dumps(data)
 
-        def deserialize(cls, serialized_data: str) -> Any:
+        def deserialize(self, serialized_data: str) -> Any:
             return json.loads(serialized_data)
 
     # Configure DBOS with a JSON-based custom serializer
diff --git a/tests/test_patch.py b/tests/test_patch.py
new file mode 100644
index 00000000..5d6490c4
--- /dev/null
+++ b/tests/test_patch.py
@@ -0,0 +1,312 @@
+# mypy: disable-error-code="no-redef"
+import pytest
+
+from dbos import DBOS, DBOSConfig
+from dbos._error import DBOSUnexpectedStepError
+from dbos._utils import GlobalParams
+
+
+def test_patch(dbos: DBOS, config: DBOSConfig) -> None:
+    DBOS.destroy(destroy_registry=True)
+    config["enable_patching"] = True
+    DBOS(config=config)
+
+    @DBOS.step()
+    def step_one() -> int:
+        return 1
+
+    @DBOS.step()
+    def step_two() -> int:
+        return 2
+
+    @DBOS.step()
+    def step_three() -> int:
+        return 3
+
+    @DBOS.workflow()
+    def workflow() -> int:
+        a = step_one()
+        b = step_two()
+        return a + b
+
+    DBOS.launch()
+
+    # Register and run the first version of a workflow
+    handle = DBOS.start_workflow(workflow)
+    v1_id = handle.workflow_id
+    assert handle.get_result() == 3
+
+    # Recreate DBOS with a new (patched) version of a workflow
+    DBOS.destroy(destroy_registry=True)
+    DBOS(config=config)
+
+    step_one = DBOS.step()(step_one)
+    step_two = DBOS.step()(step_two)
+    step_three = DBOS.step()(step_three)
+
+    @DBOS.workflow()
+    def workflow() -> int:
+        if DBOS.patch("v2"):
+            a = step_three()
+        else:
+            a = step_one()
+        b = step_two()
+        return a + b
+
+    DBOS.launch()
+
+    # Verify a new execution runs the post-patch workflow
+    # and stores a patch marker
+    handle = DBOS.start_workflow(workflow)
+    v2_id = handle.workflow_id
+    assert handle.get_status().app_version == "PATCHING_ENABLED"
+    assert handle.get_result() == 5
+    steps = DBOS.list_workflow_steps(handle.workflow_id)
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 3
+    assert steps[0]["function_name"] == "DBOS.patch-v2"
+
+    # Verify an execution containing the patch marker
+    # can recover past the patch marker
+    handle = DBOS.fork_workflow(v2_id, 3)
+    assert handle.get_result() == 5
+    steps = DBOS.list_workflow_steps(handle.workflow_id)
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 3
+    assert steps[0]["function_name"] == "DBOS.patch-v2"
+
+    # Verify an old execution runs the pre-patch workflow
+    # and does not store a patch marker
+    handle = DBOS.fork_workflow(v1_id, 2)
+    assert handle.get_result() == 3
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 2
+
+    # Recreate DBOS with a another (patched) version of a workflow
+    DBOS.destroy(destroy_registry=True)
+    DBOS(config=config)
+
+    step_one = DBOS.step()(step_one)
+    step_two = DBOS.step()(step_two)
+    step_three = DBOS.step()(step_three)
+
+    @DBOS.workflow()
+    def workflow() -> int:
+        if DBOS.patch("v3"):
+            a = step_two()
+        elif DBOS.patch("v2"):
+            a = step_three()
+        else:
+            a = step_one()
+        b = step_two()
+        return a + b
+
+    DBOS.launch()
+
+    # Verify a new execution runs the post-patch workflow
+    # and stores a patch marker
+    handle = DBOS.start_workflow(workflow)
+    v3_id = handle.workflow_id
+    assert handle.get_result() == 4
+    steps = DBOS.list_workflow_steps(handle.workflow_id)
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 3
+    assert steps[0]["function_name"] == "DBOS.patch-v3"
+
+    # Verify an execution containing the v3 patch marker
+    # recovers to v3
+    handle = DBOS.fork_workflow(v3_id, 3)
+    assert handle.get_result() == 4
+    assert handle.get_status().app_version == "PATCHING_ENABLED"
+    steps = DBOS.list_workflow_steps(handle.workflow_id)
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 3
+    assert steps[0]["function_name"] == "DBOS.patch-v3"
+
+    # Verify an execution containing the v2 patch marker
+    # recovers to v2
+    handle = DBOS.fork_workflow(v2_id, 3)
+    assert handle.get_result() == 5
+    steps = DBOS.list_workflow_steps(handle.workflow_id)
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 3
+    assert steps[0]["function_name"] == "DBOS.patch-v2"
+
+    # Verify a v1 execution recovers the pre-patch workflow
+    # and does not store a patch marker
+    handle = DBOS.fork_workflow(v1_id, 2)
+    assert handle.get_result() == 3
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 2
+
+    # Now, let's deprecate the patch
+    DBOS.destroy(destroy_registry=True)
+    DBOS(config=config)
+
+    step_one = DBOS.step()(step_one)
+    step_two = DBOS.step()(step_two)
+    step_three = DBOS.step()(step_three)
+
+    @DBOS.workflow()
+    def workflow() -> int:
+        DBOS.deprecate_patch("v3")
+        a = step_two()
+        b = step_two()
+        return a + b
+
+    DBOS.launch()
+
+    # Verify a new execution runs the final workflow
+    # but does not store a patch marker
+    handle = DBOS.start_workflow(workflow)
+    v4_id = handle.workflow_id
+    assert handle.get_result() == 4
+    steps = DBOS.list_workflow_steps(handle.workflow_id)
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 2
+
+    # Verify an execution sans patch marker recovers correctly
+    handle = DBOS.fork_workflow(v4_id, 3)
+    assert handle.get_result() == 4
+    steps = DBOS.list_workflow_steps(handle.workflow_id)
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 2
+
+    # Verify an execution containing the v3 patch marker
+    # recovers to v3
+    handle = DBOS.fork_workflow(v3_id, 3)
+    assert handle.get_result() == 4
+    steps = DBOS.list_workflow_steps(handle.workflow_id)
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 3
+    assert steps[0]["function_name"] == "DBOS.patch-v3"
+
+    # Verify an execution containing the v2 patch marker
+    # cleanly fails
+    handle = DBOS.fork_workflow(v2_id, 3)
+    with pytest.raises(DBOSUnexpectedStepError):
+        handle.get_result()
+
+    # Verify a v1 execution cleanly fails
+    handle = DBOS.fork_workflow(v1_id, 2)
+    with pytest.raises(DBOSUnexpectedStepError):
+        handle.get_result()
+
+    # Finally, let's remove the patch
+    DBOS.destroy(destroy_registry=True)
+    DBOS(config=config)
+
+    step_one = DBOS.step()(step_one)
+    step_two = DBOS.step()(step_two)
+    step_three = DBOS.step()(step_three)
+
+    @DBOS.workflow()
+    def workflow() -> int:
+        a = step_two()
+        b = step_two()
+        return a + b
+
+    DBOS.launch()
+
+    # Verify an execution from the deprecated patch works
+    # sans patch marker
+    handle = DBOS.fork_workflow(v4_id, 3)
+    assert handle.get_result() == 4
+    steps = DBOS.list_workflow_steps(handle.workflow_id)
+    assert len(DBOS.list_workflow_steps(handle.workflow_id)) == 2
+
+    # Verify an execution containing the v3 patch marker
+    # cleanly fails
+    handle = DBOS.fork_workflow(v3_id, 3)
+    with pytest.raises(DBOSUnexpectedStepError):
+        handle.get_result()
+
+    # Verify an execution containing the v2 patch marker
+    # cleanly fails
+    handle = DBOS.fork_workflow(v2_id, 3)
+    with pytest.raises(DBOSUnexpectedStepError):
+        handle.get_result()
+
+    # Verify a v1 execution cleanly fails
+    handle = DBOS.fork_workflow(v1_id, 2)
+    with pytest.raises(DBOSUnexpectedStepError):
+        handle.get_result()
+
+
+@pytest.mark.asyncio
+async def test_patch_async(dbos: DBOS, config: DBOSConfig) -> None:
+    DBOS.destroy(destroy_registry=True)
+    config["enable_patching"] = True
+    DBOS(config=config)
+
+    @DBOS.step()
+    async def step_one() -> int:
+        return 1
+
+    @DBOS.step()
+    async def step_two() -> int:
+        return 2
+
+    @DBOS.step()
+    async def step_three() -> int:
+        return 3
+
+    @DBOS.workflow()
+    async def workflow() -> int:
+        a = await step_one()
+        b = await step_two()
+        return a + b
+
+    DBOS.launch()
+
+    # Register and run the first version of a workflow
+    handle = await DBOS.start_workflow_async(workflow)
+    v1_id = handle.workflow_id
+    assert await handle.get_result() == 3
+
+    # Recreate DBOS with a new (patched) version of a workflow
+    DBOS.destroy(destroy_registry=True)
+    DBOS(config=config)
+
+    step_one = DBOS.step()(step_one)
+    step_two = DBOS.step()(step_two)
+    step_three = DBOS.step()(step_three)
+
+    @DBOS.workflow()
+    async def workflow() -> int:
+        if await DBOS.patch_async("v2"):
+            a = await step_three()
+        else:
+            a = await step_one()
+        b = await step_two()
+        return a + b
+
+    DBOS.launch()
+
+    # Verify a new execution runs the post-patch workflow
+    # and stores a patch marker
+    handle = await DBOS.start_workflow_async(workflow)
+    assert await handle.get_result() == 5
+    steps = await DBOS.list_workflow_steps_async(handle.workflow_id)
+    assert len(await DBOS.list_workflow_steps_async(handle.workflow_id)) == 3
+    assert steps[0]["function_name"] == "DBOS.patch-v2"
+
+    # Verify an old execution runs the pre-patch workflow
+    # and does not store a patch marker
+    handle = await DBOS.fork_workflow_async(v1_id, 2)
+    assert await handle.get_result() == 3
+    assert len(await DBOS.list_workflow_steps_async(handle.workflow_id)) == 2
+
+    # Now, let's deprecate the patch
+    DBOS.destroy(destroy_registry=True)
+    DBOS(config=config)
+
+    step_one = DBOS.step()(step_one)
+    step_two = DBOS.step()(step_two)
+    step_three = DBOS.step()(step_three)
+
+    @DBOS.workflow()
+    async def workflow() -> int:
+        await DBOS.deprecate_patch_async("v3")
+        a = await step_two()
+        b = await step_two()
+        return a + b
+
+    DBOS.launch()
+
+    # Verify a new execution runs the final workflow
+    # but does not store a patch marker
+    handle = await DBOS.start_workflow_async(workflow)
+    assert await handle.get_result() == 4
+    steps = await DBOS.list_workflow_steps_async(handle.workflow_id)
+    assert len(await DBOS.list_workflow_steps_async(handle.workflow_id)) == 2