diff --git a/.changelog/_unreleased.toml b/.changelog/_unreleased.toml new file mode 100644 index 0000000..d5f3ad4 --- /dev/null +++ b/.changelog/_unreleased.toml @@ -0,0 +1,5 @@ +[[entries]] +id = "6d0f41f2-f7f9-4808-af65-196a7a909b4f" +type = "fix" +description = "Fix #47: Union with Literal in them can now de/serialize" +author = "@rhaps0dy" diff --git a/databind/src/databind/json/converters.py b/databind/src/databind/json/converters.py index e04df80..fbb1322 100644 --- a/databind/src/databind/json/converters.py +++ b/databind/src/databind/json/converters.py @@ -763,13 +763,19 @@ def _check_style_compatibility(self, ctx: Context, style: str, value: t.Any) -> def convert(self, ctx: Context) -> t.Any: datatype = ctx.datatype union: t.Optional[Union] + literal_types: t.List[TypeHint] = [] + if isinstance(datatype, UnionTypeHint): if datatype.has_none_type(): raise NotImplementedError("unable to handle Union type with None in it") - if not all(isinstance(a, ClassTypeHint) for a in datatype): - raise NotImplementedError(f"members of plain Union must be concrete types: {datatype}") - members = {t.cast(ClassTypeHint, a).type.__name__: a for a in datatype} - if len(members) != len(datatype): + + literal_types = [a for a in datatype if isinstance(a, LiteralTypeHint)] + non_literal_types = [a for a in datatype if not isinstance(a, LiteralTypeHint)] + if not all(isinstance(a, ClassTypeHint) for a in non_literal_types): + raise NotImplementedError(f"members of plain Union must be concrete or Literal types: {datatype}") + + members = {t.cast(ClassTypeHint, a).type.__name__: a for a in non_literal_types} + if len(members) != len(non_literal_types): raise NotImplementedError(f"members of plain Union cannot have overlapping type names: {datatype}") union = Union(members, Union.BEST_MATCH) elif isinstance(datatype, (AnnotatedTypeHint, ClassTypeHint)): @@ -788,6 +794,11 @@ def convert(self, ctx: Context) -> t.Any: return ctx.spawn(ctx.value, member_type, None).convert() except ConversionError as exc: errors.append((exc.origin, exc)) + for literal_type in literal_types: + try: + return ctx.spawn(ctx.value, literal_type, None).convert() + except ConversionError as exc: + errors.append((exc.origin, exc)) raise ConversionError( self, ctx, diff --git a/databind/src/databind/json/tests/converters_test.py b/databind/src/databind/json/tests/converters_test.py index 0de421c..2b6cadf 100644 --- a/databind/src/databind/json/tests/converters_test.py +++ b/databind/src/databind/json/tests/converters_test.py @@ -30,6 +30,7 @@ DatetimeConverter, DecimalConverter, EnumConverter, + LiteralConverter, MappingConverter, OptionalConverter, PlainDatatypeConverter, @@ -328,6 +329,22 @@ def test_union_converter_best_match(direction: Direction) -> None: assert mapper.convert(direction, 42, t.Union[int, str]) == 42 +@pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE)) +def test_union_converter_best_match_literal(direction: Direction) -> None: + mapper = make_mapper([UnionConverter(), PlainDatatypeConverter(), LiteralConverter()]) + + LiteralUnionType = t.Union[int, t.Literal["hi"], t.Literal["bye"]] + + if direction == Direction.DESERIALIZE: + assert mapper.convert(direction, 42, LiteralUnionType) == 42 + assert mapper.convert(direction, "hi", LiteralUnionType) == "hi" + assert mapper.convert(direction, "bye", LiteralUnionType) == "bye" + else: + assert mapper.convert(direction, 42, LiteralUnionType) == 42 + assert mapper.convert(direction, "hi", LiteralUnionType) == "hi" + assert mapper.convert(direction, "bye", LiteralUnionType) == "bye" + + @pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE)) def test_union_converter_keyed(direction: Direction) -> None: mapper = make_mapper([UnionConverter(), PlainDatatypeConverter()]) @@ -339,6 +356,31 @@ def test_union_converter_keyed(direction: Direction) -> None: assert mapper.convert(direction, 42, th) == {"int": 42} +@pytest.mark.xfail +@pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE)) +def test_union_converter_keyed_literal(direction: Direction) -> None: + mapper = make_mapper([UnionConverter(), PlainDatatypeConverter(), LiteralConverter()]) + + th = te.Annotated[ + t.Union[int, t.Literal["hi"], t.Literal["bye"]], + Union({"int": int, "HiType": t.Literal["hi"], "ByeType": t.Literal["bye"]}, style=Union.KEYED), + ] + if direction == Direction.DESERIALIZE: + assert mapper.convert(direction, {"int": 42}, th) == 42 + assert mapper.convert(direction, {"HiType": "hi"}, th) == "hi" + assert mapper.convert(direction, {"ByeType": "bye"}, th) == "bye" + + with pytest.raises(ConversionError): + mapper.convert(direction, {"ByeType": "hi"}, th) + else: + assert mapper.convert(direction, 42, th) == {"int": 42} + assert mapper.convert(direction, "hi", th) == {"HiType": "hi"} + assert mapper.convert(direction, "bye", th) == {"ByeType": "bye"} + + with pytest.raises(ConversionError): + mapper.convert(direction, {"ByeType": "hi"}, th) + + @pytest.mark.parametrize("direction", (Direction.SERIALIZE, Direction.DESERIALIZE)) def test_union_converter_flat_plain_types_not_supported(direction: Direction) -> None: mapper = make_mapper([UnionConverter(), PlainDatatypeConverter()])