Skip to content

Commit e82772b

Browse files
committed
feat(convert): create an encoder closure for efficient type converting
1 parent ab652da commit e82772b

File tree

2 files changed

+270
-93
lines changed

2 files changed

+270
-93
lines changed

python/cocoindex/convert.py

Lines changed: 264 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -64,106 +64,287 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool
6464
)
6565

6666

67-
def _get_type_info_safe(type_to_analyze: Any) -> AnalyzedTypeInfo:
67+
def _get_type_info(type_to_analyze: Any) -> AnalyzedTypeInfo:
6868
"""Safely get type info, bypassing cache if type is not hashable."""
6969

7070
@functools.cache
71-
def _get_cached_type_info() -> AnalyzedTypeInfo:
71+
def _get_cached_type_info(t: Any) -> AnalyzedTypeInfo:
7272
"""cache the computed type information for a given type."""
73-
return analyze_type_info(type_to_analyze)
73+
return analyze_type_info(t)
7474

7575
try:
7676
return _get_cached_type_info(type_to_analyze)
77-
except TypeError:
77+
except TypeError: # The type is not hashable
7878
return analyze_type_info(type_to_analyze)
7979

8080

81-
def _encode_engine_value_core(
82-
value: Any,
83-
type_info: AnalyzedTypeInfo | None = None,
84-
) -> Any:
85-
"""Core encoding logic for converting Python values to engine values."""
86-
87-
if dataclasses.is_dataclass(value):
88-
fields = dataclasses.fields(value)
89-
return [
90-
_encode_engine_value_core(
91-
getattr(value, f.name),
92-
type_info=_get_type_info_safe(f.type),
81+
def _make_encoder_closure(type_info: AnalyzedTypeInfo | None) -> Callable[[Any], Any]:
82+
"""
83+
Create an encoder closure for a specific type.
84+
"""
85+
if type_info is None:
86+
# For untyped encoding, fall back to basic logic
87+
def encode_untyped(value: Any) -> Any:
88+
if dataclasses.is_dataclass(value):
89+
fields = dataclasses.fields(value)
90+
return [
91+
_make_encoder_closure(_get_type_info(f.type))(
92+
getattr(value, f.name)
93+
)
94+
for f in fields
95+
]
96+
97+
if is_namedtuple_type(type(value)):
98+
annotations = type(value).__annotations__
99+
return [
100+
_make_encoder_closure(
101+
_get_type_info(annotations.get(name))
102+
if annotations.get(name)
103+
else None
104+
)(getattr(value, name))
105+
for name in value._fields
106+
]
107+
108+
if isinstance(value, np.number):
109+
return value.item()
110+
111+
if isinstance(value, np.ndarray):
112+
return value
113+
114+
if isinstance(value, (list, tuple)):
115+
return [_make_encoder_closure(None)(v) for v in value]
116+
117+
if isinstance(value, dict):
118+
# Handle empty dict
119+
if not value:
120+
return value
121+
122+
# Handle KTable
123+
first_val = next(iter(value.values()))
124+
if is_struct_type(type(first_val)):
125+
return [
126+
[_make_encoder_closure(None)(k)]
127+
+ _make_encoder_closure(None)(v)
128+
for k, v in value.items()
129+
]
130+
131+
return value
132+
133+
return encode_untyped
134+
135+
variant = type_info.variant
136+
137+
# Handle JSON types
138+
if isinstance(variant, AnalyzedBasicType) and variant.kind == "Json":
139+
140+
def encode_json_dict(value: Any) -> Any:
141+
if isinstance(value, dict):
142+
# Handle empty dict
143+
if not value:
144+
return value
145+
146+
# Handle KTable
147+
first_val = next(iter(value.values()))
148+
if is_struct_type(type(first_val)):
149+
untyped_encoder = _make_encoder_closure(None)
150+
return [
151+
[untyped_encoder(k)] + untyped_encoder(v)
152+
for k, v in value.items()
153+
]
154+
155+
return value
156+
157+
return encode_json_dict
158+
159+
# Handle Any types and special numpy cases
160+
if isinstance(variant, AnalyzedAnyType):
161+
162+
def encode_any_type(value: Any) -> Any:
163+
# Handle numpy types first
164+
if isinstance(value, np.number):
165+
return value.item()
166+
if isinstance(value, np.ndarray):
167+
return value
168+
169+
# Handle tuples - convert to lists for Any type
170+
if isinstance(value, tuple):
171+
return [_make_encoder_closure(None)(v) for v in value]
172+
173+
# Handle dataclasses
174+
if dataclasses.is_dataclass(value):
175+
fields = dataclasses.fields(value)
176+
return [
177+
_make_encoder_closure(_get_type_info(f.type))(
178+
getattr(value, f.name)
179+
)
180+
for f in fields
181+
]
182+
183+
# Handle namedtuples
184+
if is_namedtuple_type(type(value)):
185+
annotations = type(value).__annotations__
186+
return [
187+
_make_encoder_closure(
188+
_get_type_info(annotations.get(name))
189+
if annotations.get(name)
190+
else None
191+
)(getattr(value, name))
192+
for name in value._fields
193+
]
194+
195+
# Handle lists
196+
if isinstance(value, list):
197+
return [_make_encoder_closure(None)(v) for v in value]
198+
199+
# Handle dicts
200+
if isinstance(value, dict):
201+
# Handle empty dict
202+
if not value:
203+
return value
204+
205+
# Handle KTable
206+
first_val = next(iter(value.values()))
207+
if is_struct_type(type(first_val)):
208+
return [
209+
[_make_encoder_closure(None)(k)]
210+
+ _make_encoder_closure(None)(v)
211+
for k, v in value.items()
212+
]
213+
214+
return value
215+
216+
return encode_any_type
217+
218+
# Handle basic types
219+
if isinstance(variant, AnalyzedBasicType):
220+
221+
def encode_basic_with_numpy(value: Any) -> Any:
222+
# Handle numpy types for basic types
223+
if isinstance(value, np.number):
224+
return value.item()
225+
if isinstance(value, np.ndarray):
226+
return value
227+
return value
228+
229+
return encode_basic_with_numpy
230+
231+
# Handle lists
232+
if isinstance(variant, AnalyzedListType):
233+
if variant.elem_type:
234+
elem_encoder = _make_encoder_closure(_get_type_info(variant.elem_type))
235+
return (
236+
lambda value: [elem_encoder(v) for v in value]
237+
if isinstance(value, (list, tuple))
238+
else value
93239
)
94-
for f in fields
95-
]
96-
97-
if is_namedtuple_type(type(value)):
98-
annotations = type(value).__annotations__
99-
return [
100-
_encode_engine_value_core(
101-
getattr(value, name),
102-
type_info=_get_type_info_safe(annotations.get(name))
103-
if annotations.get(name)
104-
else None,
240+
else:
241+
fallback_encoder = _make_encoder_closure(None)
242+
return (
243+
lambda value: [fallback_encoder(v) for v in value]
244+
if isinstance(value, (list, tuple))
245+
else value
105246
)
106-
for name in value._fields
107-
]
108247

109-
if isinstance(value, np.number):
110-
return value.item()
248+
# Handle dicts
249+
if isinstance(variant, AnalyzedDictType):
250+
if variant.value_type:
251+
value_encoder = _make_encoder_closure(_get_type_info(variant.value_type))
252+
untyped_encoder = _make_encoder_closure(None)
111253

112-
if isinstance(value, np.ndarray):
113-
return value
254+
def encode_dict(value: Any) -> Any:
255+
if not isinstance(value, dict):
256+
return value
114257

115-
if isinstance(value, (list, tuple)):
116-
if (
117-
type_info
118-
and isinstance(type_info.variant, AnalyzedListType)
119-
and type_info.variant.elem_type
120-
):
121-
elem_type_info = _get_type_info_safe(type_info.variant.elem_type)
122-
return [
123-
_encode_engine_value_core(
124-
v,
125-
type_info=elem_type_info,
126-
)
127-
for v in value
128-
]
258+
# Handle empty dict
259+
if not value:
260+
return []
261+
262+
# Handle KTable
263+
first_val = next(iter(value.values()))
264+
if is_struct_type(type(first_val)):
265+
return [
266+
[untyped_encoder(k)] + untyped_encoder(v)
267+
for k, v in value.items()
268+
]
269+
270+
# Handle regular dict
271+
return {k: value_encoder(v) for k, v in value.items()}
272+
273+
return encode_dict
129274
else:
130-
return [_encode_engine_value_core(v, type_info=None) for v in value]
131-
132-
if isinstance(value, dict):
133-
# Determine if this is a JSON type
134-
is_json_type = False
135-
if type_info and isinstance(type_info.variant, AnalyzedBasicType):
136-
is_json_type = type_info.variant.kind == "Json"
137-
138-
# Handle empty dict
139-
if not value:
140-
return value if (not type_info or is_json_type) else []
141-
142-
# Handle KTable
143-
first_val = next(iter(value.values()))
144-
if is_struct_type(type(first_val)):
145-
return [
146-
[_encode_engine_value_core(k, type_info=None)]
147-
+ _encode_engine_value_core(v, type_info=None)
148-
for k, v in value.items()
149-
]
275+
return lambda value: value
150276

151-
# Handle regular dict
152-
if (
153-
type_info
154-
and isinstance(type_info.variant, AnalyzedDictType)
155-
and type_info.variant.value_type
156-
):
157-
value_type_info = _get_type_info_safe(type_info.variant.value_type)
158-
return {
159-
k: _encode_engine_value_core(
160-
v,
161-
type_info=value_type_info,
277+
# Handle struct types
278+
if isinstance(variant, AnalyzedStructType):
279+
struct_type = variant.struct_type
280+
281+
if dataclasses.is_dataclass(struct_type):
282+
fields = dataclasses.fields(struct_type)
283+
field_encoders = [
284+
_make_encoder_closure(_get_type_info(f.type)) for f in fields
285+
]
286+
field_names = [f.name for f in fields]
287+
288+
def encode_dataclass(value: Any) -> Any:
289+
if not dataclasses.is_dataclass(value):
290+
return value
291+
return [
292+
encoder(getattr(value, name))
293+
for encoder, name in zip(field_encoders, field_names)
294+
]
295+
296+
return encode_dataclass
297+
298+
elif is_namedtuple_type(struct_type):
299+
annotations = struct_type.__annotations__
300+
field_names = list(getattr(struct_type, "_fields", ()))
301+
field_encoders = [
302+
_make_encoder_closure(
303+
_get_type_info(annotations.get(name))
304+
if annotations.get(name)
305+
else None
162306
)
163-
for k, v in value.items()
164-
}
307+
for name in field_names
308+
]
309+
310+
def encode_namedtuple(value: Any) -> Any:
311+
if not is_namedtuple_type(type(value)):
312+
return value
313+
return [
314+
encoder(getattr(value, name))
315+
for encoder, name in zip(field_encoders, field_names)
316+
]
317+
318+
return encode_namedtuple
319+
320+
# Handle numpy types
321+
def encode_with_numpy_check(value: Any) -> Any:
322+
if isinstance(value, np.number):
323+
return value.item()
324+
if isinstance(value, np.ndarray):
325+
return value
326+
if isinstance(value, tuple):
327+
return [_make_encoder_closure(None)(v) for v in value]
328+
return value
329+
330+
return encode_with_numpy_check
331+
332+
333+
def make_engine_value_encoder(type_hint: Type[Any] | str) -> Callable[[Any], Any]:
334+
"""
335+
Create an encoder closure for converting Python values to engine values.
336+
337+
Args:
338+
type_hint: Type annotation for the values to encode
165339
166-
return value
340+
Returns:
341+
A closure that encodes Python values to engine values
342+
"""
343+
type_info = _get_type_info(type_hint)
344+
if isinstance(type_info.variant, AnalyzedUnknownType):
345+
raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")
346+
347+
return _make_encoder_closure(type_info)
167348

168349

169350
def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
@@ -177,12 +358,8 @@ def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
177358
Returns:
178359
The encoded engine value
179360
"""
180-
# Analyze type once and reuse the result
181-
type_info = _get_type_info_safe(type_hint)
182-
if isinstance(type_info.variant, AnalyzedUnknownType):
183-
raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")
184-
185-
return _encode_engine_value_core(value, type_info)
361+
encoder = make_engine_value_encoder(type_hint)
362+
return encoder(value)
186363

187364

188365
def make_engine_value_decoder(

0 commit comments

Comments
 (0)