|
| 1 | +import typing |
1 | 2 | from contextlib import suppress |
2 | 3 | from functools import lru_cache |
3 | 4 | from itertools import groupby |
|
6 | 7 | from typing import Iterable |
7 | 8 | from typing import List |
8 | 9 | from typing import Optional |
| 10 | +from typing import Tuple |
9 | 11 | from typing import Type |
10 | 12 | from typing import Union |
11 | 13 |
|
12 | 14 | from pydantic.error_wrappers import ValidationError |
13 | | -from pydantic.fields import FieldInfo |
14 | 15 | from typing_extensions import get_args |
15 | 16 | from typing_extensions import get_origin |
16 | 17 |
|
|
21 | 22 | IntrospectionError = (KeyError, IndexError, AttributeError) |
22 | 23 |
|
23 | 24 |
|
def extract_root_outer_type(storage_type: Type) -> Type:
    """Return the declared outer type of a Pydantic model's `__root__` field.

    Pydantic does not keep `Optional[...]` in `outer_type_`, so when the root
    field allows None the optional wrapper is rebuilt from `type_` instead.
    """
    field = storage_type.__fields__['__root__']
    if field.allow_none:
        # NOTE: outer_type_ is not preserved for Optional fields
        return typing.Optional[field.type_]  # type: ignore
    return field.outer_type_
27 | 32 |
|
28 | 33 |
|
@lru_cache(None)
def is_array_type(storage_type: Type) -> bool:
    """TzKT can return bigmaps as objects or as arrays of key-value objects. Guess it from storage type."""
    # Plain List[...] is an array
    if get_origin(storage_type) is list:
        return True

    # A Pydantic model whose __root__ field subclasses List counts too;
    # models without a usable __root__ field fall through to the default.
    with suppress(*IntrospectionError):
        return is_array_type(extract_root_outer_type(storage_type))  # type: ignore

    # NOTE: Something else
    return False
52 | 48 |
|
53 | 49 |
|
@lru_cache(None)
def get_list_elt_type(list_type: Type[Any]) -> Type[Any]:
    """Return the element type of a list-like type.

    Handles plain `List[T]` as well as Pydantic models whose `__root__`
    field subclasses List; introspection errors propagate to the caller.
    """
    # Plain List[T] -> T
    if get_origin(list_type) is list:
        return get_args(list_type)[0]

    # Pydantic model with a List __root__ field: recurse into it
    return get_list_elt_type(extract_root_outer_type(list_type))  # type: ignore
70 | 60 |
|
71 | 61 |
|
@lru_cache(None)
def get_dict_value_type(dict_type: Type[Any], key: Optional[str] = None) -> Type[Any]:
    """Return the value type of a dict-like type.

    For Pydantic object models `key` selects the field (by name or alias);
    raises KeyError when the key is missing or was not provided.
    """
    # Plain Dict[K, V] -> V
    if get_origin(dict_type) is dict:
        return get_args(dict_type)[1]

    # Pydantic model whose __root__ field subclasses Dict: recurse into it
    with suppress(*IntrospectionError):
        return get_dict_value_type(extract_root_outer_type(dict_type), key)  # type: ignore

    if key is None:
        raise KeyError('Key name or alias is required for object introspection')

    # Pydantic object model: match the key against field names and aliases
    for field in dict_type.__fields__.values():
        if key not in (field.name, field.alias):
            continue
        # NOTE: Pydantic does not preserve outer_type_ for Optional; rebuild its type
        if field.allow_none:
            return typing.Optional[field.type_]  # type: ignore
        return field.outer_type_

    # NOTE: typically raised when probing the wrong branch of a Union
    raise KeyError('Key not found')
| 89 | + |
| 90 | + |
@lru_cache(None)
def unwrap_union_type(union_type: Type) -> Tuple[bool, Tuple[Type, ...]]:
    """Report whether the type is a Union (incl. Optional) and, if so, its arg types."""
    if get_origin(union_type) is Union:
        return True, get_args(union_type)

    # A Pydantic __root__ model may wrap a Union; look inside it
    with suppress(*IntrospectionError):
        return unwrap_union_type(extract_root_outer_type(union_type))  # type: ignore

    return False, ()
92 | 102 |
|
93 | 103 |
|
94 | 104 | def _preprocess_bigmap_diffs(diffs: Iterable[Dict[str, Any]]) -> Dict[int, Iterable[Dict[str, Any]]]: |
@@ -124,30 +134,32 @@ def _apply_bigmap_diffs( |
124 | 134 | return dict_storage |
125 | 135 |
|
126 | 136 |
|
127 | | -def _process_storage( |
128 | | - storage: Any, |
129 | | - storage_type: Type[StorageType], |
130 | | - bigmap_diffs: Dict[int, Iterable[Dict[str, Any]]], |
131 | | -) -> Any: |
| 137 | +def _process_storage(storage: Any, storage_type: Type[Any], bigmap_diffs: Dict[int, Iterable[Dict[str, Any]]]) -> Any: |
132 | 138 | """Replace bigmap pointers with actual data from diffs""" |
| 139 | + # Check if Union or Optional (== Union[Any, NoneType]) |
| 140 | + is_union, arg_types = unwrap_union_type(storage_type) # type: ignore |
| 141 | + if is_union: |
| 142 | + # NOTE: We have no way but trying every possible branch until first success |
| 143 | + for arg_type in arg_types: |
| 144 | + with suppress(*IntrospectionError): |
| 145 | + return _process_storage(storage, arg_type, bigmap_diffs) |
| 146 | + |
133 | 147 | # NOTE: Bigmap pointer, apply diffs |
134 | 148 | if isinstance(storage, int) and type(storage) != storage_type: |
135 | | - is_array = _is_array(storage_type) |
| 149 | + is_array = is_array_type(storage_type) # type: ignore |
136 | 150 | storage = _apply_bigmap_diffs(storage, bigmap_diffs, is_array) |
137 | 151 |
|
138 | 152 | # NOTE: List, process recursively |
139 | 153 | elif isinstance(storage, list): |
| 154 | + elt_type = get_list_elt_type(storage_type) # type: ignore |
140 | 155 | for i, _ in enumerate(storage): |
141 | | - for item_type in _extract_list_types(storage_type): |
142 | | - with suppress(*IntrospectionError): |
143 | | - storage[i] = _process_storage(storage[i], item_type, bigmap_diffs) |
| 156 | + storage[i] = _process_storage(storage[i], elt_type, bigmap_diffs) |
144 | 157 |
|
145 | 158 | # NOTE: Dict, process recursively |
146 | 159 | elif isinstance(storage, dict): |
147 | 160 | for key, value in storage.items(): |
148 | | - for value_type in _extract_dict_types(storage_type, key): |
149 | | - with suppress(*IntrospectionError): |
150 | | - storage[key] = _process_storage(value, value_type, bigmap_diffs) |
| 161 | + value_type = get_dict_value_type(storage_type, key) # type: ignore |
| 162 | + storage[key] = _process_storage(value, value_type, bigmap_diffs) |
151 | 163 |
|
152 | 164 | else: |
153 | 165 | pass |
|
0 commit comments