|
9 | 9 | import inspect |
10 | 10 | import warnings |
11 | 11 | from enum import Enum |
12 | | -from typing import Any, Callable, Mapping, Type, get_origin |
| 12 | +from typing import Any, Callable, Mapping, Sequence, Type, get_origin |
13 | 13 |
|
14 | 14 | import numpy as np |
15 | 15 |
|
@@ -170,6 +170,37 @@ def encode_basic_value(value: Any) -> Any: |
170 | 170 | return encode_basic_value |
171 | 171 |
|
172 | 172 |
|
| 173 | +def make_engine_key_decoder( |
| 174 | + field_path: list[str], |
| 175 | + key_fields_schema: list[dict[str, Any]], |
| 176 | + dst_type_info: AnalyzedTypeInfo, |
| 177 | +) -> Callable[[Any], Any]: |
| 178 | + """ |
| 179 | + Create an encoder closure for a key type. |
| 180 | + """ |
| 181 | + if len(key_fields_schema) == 1 and isinstance( |
| 182 | + dst_type_info.variant, (AnalyzedBasicType, AnalyzedAnyType) |
| 183 | + ): |
| 184 | + single_key_decoder = make_engine_value_decoder( |
| 185 | + field_path, |
| 186 | + key_fields_schema[0]["type"], |
| 187 | + dst_type_info, |
| 188 | + for_key=True, |
| 189 | + ) |
| 190 | + |
| 191 | + def key_decoder(value: list[Any]) -> Any: |
| 192 | + return single_key_decoder(value[0]) |
| 193 | + |
| 194 | + return key_decoder |
| 195 | + |
| 196 | + return make_engine_struct_decoder( |
| 197 | + field_path, |
| 198 | + key_fields_schema, |
| 199 | + dst_type_info, |
| 200 | + for_key=True, |
| 201 | + ) |
| 202 | + |
| 203 | + |
173 | 204 | def make_engine_value_decoder( |
174 | 205 | field_path: list[str], |
175 | 206 | src_type: dict[str, Any], |
@@ -244,31 +275,11 @@ def decode(value: Any) -> Any | None: |
244 | 275 | ) |
245 | 276 |
|
246 | 277 | num_key_parts = src_type.get("num_key_parts", 1) |
247 | | - key_type_info = analyze_type_info(key_type) |
248 | | - key_decoder: Callable[..., Any] | None = None |
249 | | - if ( |
250 | | - isinstance( |
251 | | - key_type_info.variant, (AnalyzedBasicType, AnalyzedAnyType) |
252 | | - ) |
253 | | - and num_key_parts == 1 |
254 | | - ): |
255 | | - single_key_decoder = make_engine_value_decoder( |
256 | | - field_path, |
257 | | - engine_fields_schema[0]["type"], |
258 | | - key_type_info, |
259 | | - for_key=True, |
260 | | - ) |
261 | | - |
262 | | - def key_decoder(value: list[Any]) -> Any: |
263 | | - return single_key_decoder(value[0]) |
264 | | - |
265 | | - else: |
266 | | - key_decoder = make_engine_struct_decoder( |
267 | | - field_path, |
268 | | - engine_fields_schema[0:num_key_parts], |
269 | | - key_type_info, |
270 | | - for_key=True, |
271 | | - ) |
| 278 | + key_decoder = make_engine_key_decoder( |
| 279 | + field_path, |
| 280 | + engine_fields_schema[0:num_key_parts], |
| 281 | + analyze_type_info(key_type), |
| 282 | + ) |
272 | 283 | value_decoder = make_engine_struct_decoder( |
273 | 284 | field_path, |
274 | 285 | engine_fields_schema[num_key_parts:], |
|
0 commit comments