diff --git a/azure/durable_functions/decorators/durable_app.py b/azure/durable_functions/decorators/durable_app.py index 62b5b704..d13f8525 100644 --- a/azure/durable_functions/decorators/durable_app.py +++ b/azure/durable_functions/decorators/durable_app.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from .metadata import OrchestrationTrigger, ActivityTrigger, EntityTrigger,\ +from .metadata import OrchestrationTrigger, ActivityTrigger, EntityTrigger, \ DurableClient from typing import Callable, Optional from azure.durable_functions.entity import Entity diff --git a/azure/durable_functions/models/TaskOrchestrationExecutor.py b/azure/durable_functions/models/TaskOrchestrationExecutor.py index 6bb1a049..42e00c69 100644 --- a/azure/durable_functions/models/TaskOrchestrationExecutor.py +++ b/azure/durable_functions/models/TaskOrchestrationExecutor.py @@ -9,7 +9,7 @@ from collections import namedtuple import json from ..models.entities.ResponseMessage import ResponseMessage -from azure.functions._durable_functions import _deserialize_custom_object +from azure.functions._durable_functions import _deserialize_custom_object, _serialize_custom_object class TaskOrchestrationExecutor: @@ -276,12 +276,12 @@ def get_orchestrator_state_str(self) -> str: message contains in it the string representation of the orchestration's state """ - if(self.output is not None): + if (self.output is not None): try: # Attempt to serialize the output. If serialization fails, raise an # error indicating that the orchestration output is not serializable, # which is not permitted in durable Python functions. - json.dumps(self.output) + json.dumps(self.output, default=_serialize_custom_object) except Exception as e: self.output = None self.exception = e diff --git a/tests/orchestrator/orchestrator_test_utils.py b/tests/orchestrator/orchestrator_test_utils.py index b4ca6fcf..e7f76da4 100644 --- a/tests/orchestrator/orchestrator_test_utils.py +++ b/tests/orchestrator/orchestrator_test_utils.py @@ -1,5 +1,7 @@ import json from typing import Callable, Iterator, Any, Dict, List + +from azure.functions._durable_functions import _deserialize_custom_object from jsonschema import validate from azure.durable_functions.models import DurableOrchestrationContext, DurableEntityContext @@ -71,7 +73,7 @@ def get_orchestration_state_result( orchestrator = Orchestrator(user_code) result_of_handle = orchestrator.handle( DurableOrchestrationContext.from_json(context_as_string)) - result = json.loads(result_of_handle) + result = json.loads(result_of_handle, object_hook=_deserialize_custom_object) return result def get_entity_state_result( diff --git a/tests/orchestrator/test_serialization.py b/tests/orchestrator/test_serialization.py index d92966c1..edc68606 100644 --- a/tests/orchestrator/test_serialization.py +++ b/tests/orchestrator/test_serialization.py @@ -1,3 +1,5 @@ +from dataclasses import dataclass + from azure.durable_functions.models.ReplaySchema import ReplaySchema from tests.test_utils.ContextBuilder import ContextBuilder from .orchestrator_test_utils \ @@ -28,4 +30,40 @@ def test_serialization_of_False(): expected["output"] = False assert_valid_schema(result) - assert_orchestration_state_equals(expected, result) \ No newline at end of file + assert_orchestration_state_equals(expected, result) + + +@dataclass +class CustomResult(): + message: str + + def to_json(self): + return {"message": self.message} + + @classmethod + def from_json(cls, data): + return cls(message=data["message"]) + +def generator_function_with_custom_class(context): + return CustomResult(message="Custom serialization test") + +def test_serialization_of_custom_class(): + """Test that an orchestrator can return False.""" + + context_builder = ContextBuilder("serialize custom class") + + result = get_orchestration_state_result( + context_builder, generator_function_with_custom_class) + + expected_output = CustomResult(message="Custom serialization test") + expected_state = base_expected_state(output=expected_output) + + expected_state._is_done = True + expected = expected_state.to_json() + + # Since we're essentially testing the `to_json` functionality, + # we explicitely ensure that the output is set + expected["output"] = expected_output + + assert_valid_schema(result) + assert_orchestration_state_equals(expected, result)