Skip to content

Commit 432b442

Browse files
authored
New disambiguator tweaks (#435)
* Tweak docs * Docs and tests * Restore default * Remove unused import
1 parent 43f7d0f commit 432b442

File tree

5 files changed

+239
-113
lines changed

5 files changed

+239
-113
lines changed

HISTORY.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@
88
- **Potentially breaking**: {py:func}`cattrs.gen.make_dict_structure_fn` and {py:func}`cattrs.gen.typeddicts.make_dict_structure_fn` will use the values for the `detailed_validation` and `forbid_extra_keys` parameters from the given converter by default now.
99
If you're using these functions directly, the old behavior can be restored by passing in the desired values directly.
1010
([#410](https://github.com/python-attrs/cattrs/issues/410) [#411](https://github.com/python-attrs/cattrs/pull/411))
11+
- **Potentially breaking**: The default union structuring strategy will also use fields annotated as `typing.Literal` to help guide structuring.
12+
([#391](https://github.com/python-attrs/cattrs/pull/391))
1113
- Python 3.12 is now supported. Python 3.7 is no longer supported; use older releases there.
1214
([#424](https://github.com/python-attrs/cattrs/pull/424))
15+
- Implement the `union passthrough` strategy, enabling much richer union handling for preconfigured converters. [Learn more here](https://catt.rs/en/stable/strategies.html#union-passthrough).
1316
- Introduce the `use_class_methods` strategy. Learn more [here](https://catt.rs/en/latest/strategies.html#using-class-specific-structure-and-unstructure-methods).
1417
([#405](https://github.com/python-attrs/cattrs/pull/405))
15-
- Implement the `union passthrough` strategy, enabling much richer union handling for preconfigured converters. [Learn more here](https://catt.rs/en/stable/strategies.html#union-passthrough).
1618
- The `omit` parameter of {py:func}`cattrs.override` is now of type `bool | None` (from `bool`).
1719
`None` is the new default and means to apply default _cattrs_ handling to the attribute, which is to omit the attribute if it's marked as `init=False`, and keep it otherwise.
1820
- Fix {py:func}`format_exception() <cattrs.v.format_exception>` parameter working for recursive calls to {py:func}`transform_error <cattrs.transform_error>`.
1921
([#389](https://github.com/python-attrs/cattrs/issues/389))
2022
- [_attrs_ aliases](https://www.attrs.org/en/stable/init.html#private-attributes-and-aliases) are now supported, although aliased fields still map to their attribute name instead of their alias by default when un/structuring.
2123
([#322](https://github.com/python-attrs/cattrs/issues/322) [#391](https://github.com/python-attrs/cattrs/pull/391))
22-
- Use [PDM](https://pdm.fming.dev/latest/) instead of Poetry.
23-
- _cattrs_ is now linted with [Ruff](https://beta.ruff.rs/docs/).
2424
- Fix TypedDicts with periods in their field names.
2525
([#376](https://github.com/python-attrs/cattrs/issues/376) [#377](https://github.com/python-attrs/cattrs/pull/377))
2626
- Optimize and improve unstructuring of `Optional` (unions of one type and `None`).
@@ -45,10 +45,10 @@
4545
([#420](https://github.com/python-attrs/cattrs/pull/420))
4646
- Add support for `datetime.date`s to the PyYAML preconfigured converter.
4747
([#393](https://github.com/python-attrs/cattrs/issues/393))
48+
- Use [PDM](https://pdm.fming.dev/latest/) instead of Poetry.
49+
- _cattrs_ is now linted with [Ruff](https://beta.ruff.rs/docs/).
4850
- Remove some unused lines in the unstructuring code.
4951
([#416](https://github.com/python-attrs/cattrs/pull/416))
50-
- Disambiguate a union of attrs classes where there's a `typing.Literal` tag of some sort.
51-
([#391](https://github.com/python-attrs/cattrs/pull/391))
5252

5353
## 23.1.2 (2023-06-02)
5454

docs/unions.md

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,75 @@
22

33
This sections contains information for advanced union handling.
44

5-
As mentioned in the structuring section, _cattrs_ is able to handle simple unions of _attrs_ classes automatically.
5+
As mentioned in the structuring section, _cattrs_ is able to handle simple unions of _attrs_ classes [automatically](#default-union-strategy).
66
More complex cases require converter customization (since there are many ways of handling unions).
77

88
_cattrs_ also comes with a number of strategies to help handle unions:
99

1010
- [tagged unions strategy](strategies.md#tagged-unions-strategy) mentioned below
1111
- [union passthrough strategy](strategies.md#union-passthrough), which is preapplied to all the [preconfigured](preconf.md) converters
1212

13-
## Unstructuring unions with extra metadata
13+
## Default Union Strategy
14+
15+
For convenience, _cattrs_ includes a default union structuring strategy which is a little more opinionated.
16+
17+
Given a union of several _attrs_ classes, the default union strategy will attempt to handle it in several ways.
18+
19+
First, it will look for `Literal` fields.
20+
If all members of the union contain a literal field, _cattrs_ will generate a disambiguation function based on the field.
21+
22+
```python
23+
from typing import Literal
24+
25+
@define
26+
class ClassA:
27+
field_one: Literal["one"]
28+
29+
@define
30+
class ClassB:
31+
field_one: Literal["two"]
32+
```
33+
34+
In this case, a payload containing `{"field_one": "one"}` will produce an instance of `ClassA`.
35+
36+
````{note}
37+
The following snippet can be used to disable the use of literal fields, restoring the previous behavior.
38+
39+
```python
40+
from functools import partial
41+
from cattrs.disambiguators import is_supported_union
42+
43+
converter.register_structure_hook_factory(
44+
is_supported_union,
45+
partial(converter._gen_attrs_union_structure, use_literals=False),
46+
)
47+
```
48+
49+
````
50+
51+
If there are no appropriate fields, the strategy will examine the classes for **unique required fields**.
52+
53+
So, given a union of `ClassA` and `ClassB`:
54+
55+
```python
56+
@define
57+
class ClassA:
58+
field_one: str
59+
field_with_default: str = "a default"
60+
61+
@define
62+
class ClassB:
63+
field_two: str
64+
```
65+
66+
the strategy will determine that if a payload contains the key `field_one` it should be handled as `ClassA`, and if it contains the key `field_two` it should be handled as `ClassB`.
67+
The field `field_with_default` will not be considered since it has a default value, so it gets treated as optional.
68+
69+
```{versionchanged} 23.2.0
70+
Literals can now be potentially used to disambiguate.
71+
```
72+
73+
## Unstructuring Unions with Extra Metadata
1474

1575
```{note}
1676
_cattrs_ comes with the [tagged unions strategy](strategies.md#tagged-unions-strategy) for handling this exact use-case since version 23.1.

src/cattrs/converters.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
Union,
2020
)
2121

22-
from attr import Attribute
23-
from attr import has as attrs_has
24-
from attr import resolve_types
22+
from attrs import Attribute
23+
from attrs import has as attrs_has
24+
from attrs import resolve_types
2525

2626
from ._compat import (
2727
FrozenSetSubscriptable,
@@ -55,7 +55,7 @@
5555
is_typeddict,
5656
is_union_type,
5757
)
58-
from .disambiguators import create_default_dis_func
58+
from .disambiguators import create_default_dis_func, is_supported_union
5959
from .dispatch import MultiStrategyDispatch
6060
from .errors import (
6161
IterableValidationError,
@@ -96,16 +96,6 @@ def _subclass(typ: Type) -> Callable[[Type], bool]:
9696
return lambda cls: issubclass(cls, typ)
9797

9898

99-
def is_attrs_union(typ: Type) -> bool:
100-
return is_union_type(typ) and all(has(get_origin(e) or e) for e in typ.__args__)
101-
102-
103-
def is_attrs_union_or_none(typ: Type) -> bool:
104-
return is_union_type(typ) and all(
105-
e is NoneType or has(get_origin(e) or e) for e in typ.__args__
106-
)
107-
108-
10999
def is_optional(typ: Type) -> bool:
110100
return is_union_type(typ) and NoneType in typ.__args__ and len(typ.__args__) == 2
111101

@@ -204,7 +194,7 @@ def __init__(
204194
(is_frozenset, self._structure_frozenset),
205195
(is_tuple, self._structure_tuple),
206196
(is_mapping, self._structure_dict),
207-
(is_attrs_union_or_none, self._gen_attrs_union_structure, True),
197+
(is_supported_union, self._gen_attrs_union_structure, True),
208198
(
209199
lambda t: is_union_type(t) and t in self._union_struct_registry,
210200
self._structure_union,
@@ -411,17 +401,19 @@ def _gen_structure_generic(self, cl: Type[T]) -> DictStructureFn[T]:
411401
)
412402

413403
def _gen_attrs_union_structure(
414-
self, cl: Any
404+
self, cl: Any, use_literals: bool = True
415405
) -> Callable[[Any, Type[T]], Optional[Type[T]]]:
416406
"""
417407
Generate a structuring function for a union of attrs classes (and maybe None).
408+
409+
:param use_literals: Whether to consider literal fields.
418410
"""
419-
dis_fn = self._get_dis_func(cl)
411+
dis_fn = self._get_dis_func(cl, use_literals=use_literals)
420412
has_none = NoneType in cl.__args__
421413

422414
if has_none:
423415

424-
def structure_attrs_union(obj, _):
416+
def structure_attrs_union(obj, _) -> cl:
425417
if obj is None:
426418
return None
427419
return self.structure(obj, dis_fn(obj))
@@ -719,7 +711,7 @@ def _structure_tuple(self, obj: Any, tup: Type[T]) -> T:
719711
return res
720712

721713
@staticmethod
722-
def _get_dis_func(union: Any) -> Callable[[Any], Type]:
714+
def _get_dis_func(union: Any, use_literals: bool = True) -> Callable[[Any], Type]:
723715
"""Fetch or try creating a disambiguation function for a union."""
724716
union_types = union.__args__
725717
if NoneType in union_types: # type: ignore
@@ -738,7 +730,7 @@ def _get_dis_func(union: Any) -> Callable[[Any], Type]:
738730
type_=union,
739731
)
740732

741-
return create_default_dis_func(*union_types)
733+
return create_default_dis_func(*union_types, use_literals=use_literals)
742734

743735
def __deepcopy__(self, _) -> "BaseConverter":
744736
return self.copy()

src/cattrs/disambiguators.py

Lines changed: 67 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,71 +2,86 @@
22
from collections import OrderedDict, defaultdict
33
from functools import reduce
44
from operator import or_
5-
from typing import Any, Callable, Dict, Mapping, Optional, Type, Union
5+
from typing import Any, Callable, Dict, Mapping, Optional, Set, Type, Union
66

7-
from attr import NOTHING, fields, fields_dict
7+
from attrs import NOTHING, fields, fields_dict
88

9-
from cattrs._compat import get_args, get_origin, is_literal
9+
from ._compat import get_args, get_origin, has, is_literal, is_union_type
10+
11+
__all__ = ("is_supported_union", "create_default_dis_func")
12+
13+
NoneType = type(None)
14+
15+
16+
def is_supported_union(typ: Type) -> bool:
17+
"""Whether the type is a union of attrs classes."""
18+
return is_union_type(typ) and all(
19+
e is NoneType or has(get_origin(e) or e) for e in typ.__args__
20+
)
1021

1122

1223
def create_default_dis_func(
13-
*classes: Type[Any],
24+
*classes: Type[Any], use_literals: bool = True
1425
) -> Callable[[Mapping[Any, Any]], Optional[Type[Any]]]:
15-
"""Given attr classes, generate a disambiguation function.
26+
"""Given attrs classes, generate a disambiguation function.
27+
28+
The function is based on unique fields or unique values.
1629
17-
The function is based on unique fields or unique values."""
30+
:param use_literals: Whether to try using fields annotated as literals for
31+
disambiguation.
32+
"""
1833
if len(classes) < 2:
1934
raise ValueError("At least two classes required.")
2035

2136
# first, attempt for unique values
37+
if use_literals:
38+
# requirements for a discriminator field:
39+
# (... TODO: a single fallback is OK)
40+
# - it must always be enumerated
41+
cls_candidates = [
42+
{at.name for at in fields(get_origin(cl) or cl) if is_literal(at.type)}
43+
for cl in classes
44+
]
45+
46+
# literal field names common to all members
47+
discriminators: Set[str] = cls_candidates[0]
48+
for possible_discriminators in cls_candidates:
49+
discriminators &= possible_discriminators
50+
51+
best_result = None
52+
best_discriminator = None
53+
for discriminator in discriminators:
54+
# maps Literal values (strings, ints...) to classes
55+
mapping = defaultdict(list)
56+
57+
for cl in classes:
58+
for key in get_args(
59+
fields_dict(get_origin(cl) or cl)[discriminator].type
60+
):
61+
mapping[key].append(cl)
62+
63+
if best_result is None or max(len(v) for v in mapping.values()) <= max(
64+
len(v) for v in best_result.values()
65+
):
66+
best_result = mapping
67+
best_discriminator = discriminator
68+
69+
if (
70+
best_result
71+
and best_discriminator
72+
and max(len(v) for v in best_result.values()) != len(classes)
73+
):
74+
final_mapping = {
75+
k: v[0] if len(v) == 1 else Union[tuple(v)]
76+
for k, v in best_result.items()
77+
}
2278

23-
# requirements for a discriminator field:
24-
# (... TODO: a single fallback is OK)
25-
# - it must be *required*
26-
# - it must always be enumerated
27-
cls_candidates = [
28-
{
29-
at.name
30-
for at in fields(get_origin(cl) or cl)
31-
if at.default is NOTHING and is_literal(at.type)
32-
}
33-
for cl in classes
34-
]
35-
36-
discriminators = cls_candidates[0]
37-
for possible_discriminators in cls_candidates:
38-
discriminators &= possible_discriminators
39-
40-
best_result = None
41-
best_discriminator = None
42-
for discriminator in discriminators:
43-
mapping = defaultdict(list)
44-
45-
for cl in classes:
46-
for key in get_args(fields_dict(get_origin(cl) or cl)[discriminator].type):
47-
mapping[key].append(cl)
79+
def dis_func(data: Mapping[Any, Any]) -> Optional[Type]:
80+
if not isinstance(data, Mapping):
81+
raise ValueError("Only input mappings are supported.")
82+
return final_mapping[data[best_discriminator]]
4883

49-
if best_result is None or max(len(v) for v in mapping.values()) <= max(
50-
len(v) for v in best_result.values()
51-
):
52-
best_result = mapping
53-
best_discriminator = discriminator
54-
55-
if (
56-
best_result
57-
and best_discriminator
58-
and max(len(v) for v in best_result.values()) != len(classes)
59-
):
60-
final_mapping = {
61-
k: v[0] if len(v) == 1 else Union[tuple(v)] for k, v in best_result.items()
62-
}
63-
64-
def dis_func(data: Mapping[Any, Any]) -> Optional[Type]:
65-
if not isinstance(data, Mapping):
66-
raise ValueError("Only input mappings are supported.")
67-
return final_mapping[data[best_discriminator]]
68-
69-
return dis_func
84+
return dis_func
7085

7186
# next, attempt for unique keys
7287

0 commit comments

Comments
 (0)