diff --git a/cadence/_internal/activity/_activity_executor.py b/cadence/_internal/activity/_activity_executor.py index e37e736..56f717d 100644 --- a/cadence/_internal/activity/_activity_executor.py +++ b/cadence/_internal/activity/_activity_executor.py @@ -57,7 +57,7 @@ async def _report_failure(self, task: PollForActivityTaskResponse, error: Except _logger.exception('Exception reporting activity failure') async def _report_success(self, task: PollForActivityTaskResponse, result: Any): - as_payload = await self._data_converter.to_data(result) + as_payload = await self._data_converter.to_data([result]) try: await self._client.worker_stub.RespondActivityTaskCompleted(RespondActivityTaskCompletedRequest( diff --git a/cadence/data_converter.py b/cadence/data_converter.py index e88680f..350848e 100644 --- a/cadence/data_converter.py +++ b/cadence/data_converter.py @@ -5,6 +5,7 @@ from json import JSONDecoder from msgspec import json, convert +_SPACE = ' '.encode() class DataConverter(Protocol): @@ -19,33 +20,24 @@ async def to_data(self, values: List[Any]) -> Payload: class DefaultDataConverter(DataConverter): def __init__(self) -> None: self._encoder = json.Encoder() - self._decoder = json.Decoder() - self._fallback_decoder = JSONDecoder(strict=False) + # Need to use std lib decoder in order to decode the custom whitespace delimited data format + self._decoder = JSONDecoder(strict=False) async def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]: if not payload.data: return DefaultDataConverter._convert_into([], type_hints) - if len(type_hints) > 1: - payload_str = payload.data.decode() - # Handle payloads from the Go client, which are a series of json objects rather than a json array - if not payload_str.startswith("["): - return self._decode_whitespace_delimited(payload_str, type_hints) - else: - as_list = self._decoder.decode(payload_str) - return DefaultDataConverter._convert_into(as_list, type_hints) - - as_value = self._decoder.decode(payload.data) - return DefaultDataConverter._convert_into([as_value], type_hints) + payload_str = payload.data.decode() + return self._decode_whitespace_delimited(payload_str, type_hints) def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type | None]) -> List[Any]: results: List[Any] = [] start, end = 0, len(payload) while start < end and len(results) < len(type_hints): remaining = payload[start:end] - (value, value_end) = self._fallback_decoder.raw_decode(remaining) + (value, value_end) = self._decoder.raw_decode(remaining) start += value_end + 1 results.append(value) @@ -76,10 +68,11 @@ def _get_default(type_hint: Type) -> Any: async def to_data(self, values: List[Any]) -> Payload: - data_value = values - # Don't wrap single values in a json array - if len(values) == 1: - data_value = values[0] + result = bytearray() + for index, value in enumerate(values): + self._encoder.encode_into(value, result, -1) + if index < len(values) - 1: + result += _SPACE - return Payload(data=self._encoder.encode(data_value)) + return Payload(data=bytes(result)) diff --git a/tests/cadence/_internal/activity/test_activity_executor.py b/tests/cadence/_internal/activity/test_activity_executor.py index d6aba4d..abc02cb 100644 --- a/tests/cadence/_internal/activity/test_activity_executor.py +++ b/tests/cadence/_internal/activity/test_activity_executor.py @@ -82,7 +82,7 @@ async def activity_fn(first: str, second: str): executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("activity_type", '["hello", "world"]')) + await executor.execute(fake_task("activity_type", '"hello" "world"')) worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( task_token=b'task_token', diff --git a/tests/cadence/data_converter_test.py b/tests/cadence/data_converter_test.py index cbc7ba6..068f3a5 100644 --- a/tests/cadence/data_converter_test.py +++ b/tests/cadence/data_converter_test.py @@ -23,31 +23,28 @@ class _TestDataClass: '"Hello" "world"', [str, str], ["Hello", "world"], id="space delimited" ), pytest.param( - '["Hello", "world"]', [str, str], ["Hello", "world"], id="json array" + "1", [int, int], [1, 0], id="ints" ), pytest.param( - "[1]", [int, int], [1, 0], id="ints" + "1.5", [float, float], [1.5, 0.0], id="floats" ), pytest.param( - "[1.5]", [float, float], [1.5, 0.0], id="floats" + "true", [bool, bool], [True, False], id="bools" ), pytest.param( - "[true]", [bool, bool], [True, False], id="bools" + '{"foo": "hello world", "bar": 42, "baz": {"bar": 43}}', [_TestDataClass, _TestDataClass], [_TestDataClass("hello world", 42, _TestDataClass(bar=43)), None], id="data classes" ), pytest.param( - '[{"foo": "hello world", "bar": 42, "baz": {"bar": 43}}]', [_TestDataClass, _TestDataClass], [_TestDataClass("hello world", 42, _TestDataClass(bar=43)), None], id="data classes" + '{"foo": "hello world"}', [dict, dict], [{"foo": "hello world"}, None], id="dicts" ), pytest.param( - '[{"foo": "hello world"}]', [dict, dict], [{"foo": "hello world"}, None], id="dicts" + '{"foo": 52}', [dict[str, int], dict], [{"foo": 52}, None], id="generic dicts" ), pytest.param( - '[{"foo": 52}]', [dict[str, int], dict], [{"foo": 52}, None], id="generic dicts" + '["hello"]', [list[str], list[str]], [["hello"], None], id="lists" ), pytest.param( - '[["hello"]]', [list[str], list[str]], [["hello"], None], id="lists" - ), - pytest.param( - '[["hello"]]', [set[str], set[str]], [{"hello"}, None], id="sets" + '["hello"]', [set[str], set[str]], [{"hello"}, None], id="sets" ), pytest.param( '["hello", "world"]', [list[str]], [["hello", "world"]], id="list" @@ -56,10 +53,6 @@ class _TestDataClass: '{"foo": "bar"} {"bar": 100} ["hello"] "world"', [_TestDataClass, _TestDataClass, list[str], str], [_TestDataClass(foo="bar"), _TestDataClass(bar=100), ["hello"], "world"], id="space delimited mix" ), - pytest.param( - '[{"foo": "bar"},{"bar": 100},["hello"],"world"]', [_TestDataClass, _TestDataClass, list[str], str], - [_TestDataClass(foo="bar"), _TestDataClass(bar=100), ["hello"], "world"], id="json array mix" - ), pytest.param( "", [], [], id="no input expected" ), @@ -67,7 +60,7 @@ class _TestDataClass: "", [str], [None], id="no input unexpected" ), pytest.param( - '["hello world", {"foo":"bar"}, 7]', [None, None, None], ["hello world", {"foo":"bar"}, 7], id="no type hints" + '"hello world" {"foo":"bar"} 7', [None, None, None], ["hello world", {"foo":"bar"}, 7], id="no type hints" ), pytest.param( '"hello" "world" "goodbye"', [str, str], ["hello", "world"], @@ -75,7 +68,6 @@ class _TestDataClass: ), ] ) -@pytest.mark.asyncio async def test_data_converter_from_data(json: str, types: list[Type], expected: list[Any]) -> None: converter = DefaultDataConverter() actual = await converter.from_data(Payload(data=json.encode()), types) @@ -88,18 +80,30 @@ async def test_data_converter_from_data(json: str, types: list[Type], expected: ["hello world"], '"hello world"', id="happy path" ), pytest.param( - ["hello", "world"], '["hello", "world"]', id="multiple values" + ["hello", "world"], '"hello" "world"', id="multiple values" + ), + pytest.param( + [[["hello"]], ["world"]], '[["hello"]] ["world"]', id="lists" + ), + pytest.param( + [1, 2, 10], '1 2 10', id="numeric values" + ), + pytest.param( + [True, False], 'true false', id="bool values" + ), + pytest.param( + [{'foo': 'foo', 'bar': 20}], '{"bar":20,"foo":"foo"}', id="dict values" + ), + pytest.param( + [{'foo', 'bar'}], '["bar","foo"]', id="set values" ), pytest.param( - [_TestDataClass()], '{"foo": "foo", "bar": -1, "baz": null}', id="data classes" + [_TestDataClass()], '{"foo":"foo","bar":-1,"baz":null}', id="data classes" ), ] ) -@pytest.mark.asyncio async def test_data_converter_to_data(values: list[Any], expected: str) -> None: converter = DefaultDataConverter() + converter._encoder = json.Encoder(order='deterministic') actual = await converter.to_data(values) - # Parse both rather than trying to compare strings - actual_parsed = json.decode(actual.data) - expected_parsed = json.decode(expected) - assert expected_parsed == actual_parsed \ No newline at end of file + assert actual.data.decode() == expected \ No newline at end of file