Skip to content

Commit c2f5c8c

Browse files
authored
fix: resolve forward ref in field types to be compatible with BAML (#1254)
1 parent 9d660b9 commit c2f5c8c

File tree

5 files changed

+131
-213
lines changed

5 files changed

+131
-213
lines changed

python/cocoindex/engine_object.py

Lines changed: 9 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from __future__ import annotations
66

77
import datetime
8-
import dataclasses
98
from enum import Enum
109
from typing import Any, Mapping, TypeVar, overload, get_origin
1110

@@ -24,18 +23,12 @@
2423
analyze_type_info,
2524
encode_enriched_type,
2625
is_namedtuple_type,
27-
is_pydantic_model,
2826
extract_ndarray_elem_dtype,
2927
)
3028

3129

3230
T = TypeVar("T")
3331

34-
try:
35-
import pydantic, pydantic_core
36-
except ImportError:
37-
pass
38-
3932

4033
def get_auto_default_for_type(
4134
type_info: AnalyzedTypeInfo,
@@ -175,67 +168,16 @@ def load_engine_object(expected_type: Any, v: Any) -> Any:
175168
if isinstance(variant, AnalyzedStructType):
176169
struct_type = variant.struct_type
177170
init_kwargs: dict[str, Any] = {}
178-
missing_fields: list[tuple[str, Any]] = []
179-
if dataclasses.is_dataclass(struct_type):
180-
if not isinstance(v, Mapping):
181-
raise ValueError(f"Expected dict for dataclass, got {type(v)}")
182-
183-
for dc_field in dataclasses.fields(struct_type):
184-
if dc_field.name in v:
185-
init_kwargs[dc_field.name] = load_engine_object(
186-
dc_field.type, v[dc_field.name]
187-
)
188-
else:
189-
if (
190-
dc_field.default is dataclasses.MISSING
191-
and dc_field.default_factory is dataclasses.MISSING
192-
):
193-
missing_fields.append((dc_field.name, dc_field.type))
194-
195-
elif is_namedtuple_type(struct_type):
196-
if not isinstance(v, Mapping):
197-
raise ValueError(f"Expected dict for NamedTuple, got {type(v)}")
198-
# Dict format (from dump/load functions)
199-
annotations = getattr(struct_type, "__annotations__", {})
200-
field_names = list(getattr(struct_type, "_fields", ()))
201-
field_defaults = getattr(struct_type, "_field_defaults", {})
202-
203-
for name in field_names:
204-
f_type = annotations.get(name, Any)
205-
if name in v:
206-
init_kwargs[name] = load_engine_object(f_type, v[name])
207-
elif name not in field_defaults:
208-
missing_fields.append((name, f_type))
209-
210-
elif is_pydantic_model(struct_type):
211-
if not isinstance(v, Mapping):
212-
raise ValueError(f"Expected dict for Pydantic model, got {type(v)}")
213-
214-
model_fields: dict[str, pydantic.fields.FieldInfo]
215-
if hasattr(struct_type, "model_fields"):
216-
model_fields = struct_type.model_fields # type: ignore[attr-defined]
171+
for field_info in variant.fields:
172+
if field_info.name in v:
173+
init_kwargs[field_info.name] = load_engine_object(
174+
field_info.type_hint, v[field_info.name]
175+
)
217176
else:
218-
model_fields = {}
219-
220-
for name, pyd_field in model_fields.items():
221-
if name in v:
222-
init_kwargs[name] = load_engine_object(
223-
pyd_field.annotation, v[name]
224-
)
225-
elif (
226-
getattr(pyd_field, "default", pydantic_core.PydanticUndefined)
227-
is pydantic_core.PydanticUndefined
228-
and getattr(pyd_field, "default_factory") is None
229-
):
230-
missing_fields.append((name, pyd_field.annotation))
231-
else:
232-
assert False, "Unsupported struct type"
233-
234-
for name, f_type in missing_fields:
235-
type_info = analyze_type_info(f_type)
236-
auto_default, is_supported = get_auto_default_for_type(type_info)
237-
if is_supported:
238-
init_kwargs[name] = auto_default
177+
type_info = analyze_type_info(field_info.type_hint)
178+
auto_default, is_supported = get_auto_default_for_type(type_info)
179+
if is_supported:
180+
init_kwargs[field_info.name] = auto_default
239181
return struct_type(**init_kwargs)
240182

241183
# Union with discriminator support via "kind"

python/cocoindex/engine_value.py

Lines changed: 30 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44

55
from __future__ import annotations
66

7-
import dataclasses
87
import inspect
98
import warnings
10-
from typing import Any, Callable, Mapping, TypeVar
9+
from typing import Any, Callable, TypeVar
1110

1211
import numpy as np
1312
from .typing import (
@@ -19,8 +18,8 @@
1918
AnalyzedTypeInfo,
2019
AnalyzedUnionType,
2120
AnalyzedUnknownType,
21+
AnalyzedStructFieldInfo,
2222
analyze_type_info,
23-
is_namedtuple_type,
2423
is_pydantic_model,
2524
is_numpy_number_type,
2625
ValueType,
@@ -124,69 +123,20 @@ def encode_struct_dict(value: Any) -> Any:
124123
return encode_struct_dict
125124

126125
if isinstance(variant, AnalyzedStructType):
127-
struct_type = variant.struct_type
128-
129-
if dataclasses.is_dataclass(struct_type):
130-
fields = dataclasses.fields(struct_type)
131-
field_encoders = [
132-
make_engine_value_encoder(analyze_type_info(f.type)) for f in fields
133-
]
134-
field_names = [f.name for f in fields]
135-
136-
def encode_dataclass(value: Any) -> Any:
137-
if value is None:
138-
return None
139-
return [
140-
encoder(getattr(value, name))
141-
for encoder, name in zip(field_encoders, field_names)
142-
]
143-
144-
return encode_dataclass
145-
146-
elif is_namedtuple_type(struct_type):
147-
annotations = struct_type.__annotations__
148-
field_names = list(getattr(struct_type, "_fields", ()))
149-
field_encoders = [
150-
make_engine_value_encoder(
151-
analyze_type_info(annotations[name])
152-
if name in annotations
153-
else ANY_TYPE_INFO
154-
)
155-
for name in field_names
156-
]
157-
158-
def encode_namedtuple(value: Any) -> Any:
159-
if value is None:
160-
return None
161-
return [
162-
encoder(getattr(value, name))
163-
for encoder, name in zip(field_encoders, field_names)
164-
]
165-
166-
return encode_namedtuple
167-
168-
elif is_pydantic_model(struct_type):
169-
# Type guard: ensure we have model_fields attribute
170-
if hasattr(struct_type, "model_fields"):
171-
field_names = list(struct_type.model_fields.keys()) # type: ignore[attr-defined]
172-
field_encoders = [
173-
make_engine_value_encoder(
174-
analyze_type_info(struct_type.model_fields[name].annotation) # type: ignore[attr-defined]
175-
)
176-
for name in field_names
177-
]
178-
else:
179-
raise ValueError(f"Invalid Pydantic model: {struct_type}")
126+
field_encoders = [
127+
(
128+
field_info.name,
129+
make_engine_value_encoder(analyze_type_info(field_info.type_hint)),
130+
)
131+
for field_info in variant.fields
132+
]
180133

181-
def encode_pydantic(value: Any) -> Any:
182-
if value is None:
183-
return None
184-
return [
185-
encoder(getattr(value, name))
186-
for encoder, name in zip(field_encoders, field_names)
187-
]
134+
def encode_struct(value: Any) -> Any:
135+
if value is None:
136+
return None
137+
return [encoder(getattr(value, name)) for name, encoder in field_encoders]
188138

189-
return encode_pydantic
139+
return encode_struct
190140

191141
def encode_basic_value(value: Any) -> Any:
192142
if isinstance(value, np.number):
@@ -475,51 +425,12 @@ def make_engine_struct_decoder(
475425
src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)}
476426
dst_struct_type = dst_type_variant.struct_type
477427

478-
parameters: Mapping[str, inspect.Parameter]
479-
if dataclasses.is_dataclass(dst_struct_type):
480-
parameters = inspect.signature(dst_struct_type).parameters
481-
elif is_namedtuple_type(dst_struct_type):
482-
defaults = getattr(dst_struct_type, "_field_defaults", {})
483-
fields = getattr(dst_struct_type, "_fields", ())
484-
parameters = {
485-
name: inspect.Parameter(
486-
name=name,
487-
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
488-
default=defaults.get(name, inspect.Parameter.empty),
489-
annotation=dst_struct_type.__annotations__.get(
490-
name, inspect.Parameter.empty
491-
),
492-
)
493-
for name in fields
494-
}
495-
elif is_pydantic_model(dst_struct_type):
496-
# For Pydantic models, we can use model_fields to get field information
497-
parameters = {}
498-
# Type guard: ensure we have model_fields attribute
499-
if hasattr(dst_struct_type, "model_fields"):
500-
model_fields = dst_struct_type.model_fields # type: ignore[attr-defined]
501-
else:
502-
model_fields = {}
503-
for name, field_info in model_fields.items():
504-
default_value = (
505-
field_info.default
506-
if field_info.default is not ...
507-
else inspect.Parameter.empty
508-
)
509-
parameters[name] = inspect.Parameter(
510-
name=name,
511-
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
512-
default=default_value,
513-
annotation=field_info.annotation,
514-
)
515-
else:
516-
raise ValueError(f"Unsupported struct type: {dst_struct_type}")
517-
518428
def make_closure_for_field(
519-
name: str, param: inspect.Parameter
429+
field_info: AnalyzedStructFieldInfo,
520430
) -> Callable[[list[Any]], Any]:
431+
name = field_info.name
521432
src_idx = src_name_to_idx.get(name)
522-
type_info = analyze_type_info(param.annotation)
433+
type_info = analyze_type_info(field_info.type_hint)
523434

524435
with ChildFieldPath(field_path, f".{name}"):
525436
if src_idx is not None:
@@ -531,42 +442,44 @@ def make_closure_for_field(
531442
)
532443
return lambda values: field_decoder(values[src_idx])
533444

534-
default_value = param.default
445+
default_value = field_info.default_value
535446
if default_value is not inspect.Parameter.empty:
536447
return lambda _: default_value
537448

538449
auto_default, is_supported = get_auto_default_for_type(type_info)
539450
if is_supported:
540451
warnings.warn(
541-
f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
452+
f"Field '{name}' (type {field_info.type_hint}) without default value is missing in input: "
542453
f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
543454
UserWarning,
544455
stacklevel=4,
545456
)
546457
return lambda _: auto_default
547458

548459
raise ValueError(
549-
f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
460+
f"Field '{name}' (type {field_info.type_hint}) without default value is missing in input: {''.join(field_path)}"
550461
)
551462

552-
field_value_decoder = [
553-
make_closure_for_field(name, param) for (name, param) in parameters.items()
554-
]
555-
556463
# Different construction for different struct types
557464
if is_pydantic_model(dst_struct_type):
558465
# Pydantic models prefer keyword arguments
559-
field_names = list(parameters.keys())
466+
pydantic_fields_decoder = [
467+
(field_info.name, make_closure_for_field(field_info))
468+
for field_info in dst_type_variant.fields
469+
]
560470
return lambda values: dst_struct_type(
561471
**{
562-
field_names[i]: decoder(values)
563-
for i, decoder in enumerate(field_value_decoder)
472+
field_name: decoder(values)
473+
for field_name, decoder in pydantic_fields_decoder
564474
}
565475
)
566476
else:
477+
struct_fields_decoder = [
478+
make_closure_for_field(field_info) for field_info in dst_type_variant.fields
479+
]
567480
# Dataclasses and NamedTuples can use positional arguments
568481
return lambda values: dst_struct_type(
569-
*(decoder(values) for decoder in field_value_decoder)
482+
*(decoder(values) for decoder in struct_fields_decoder)
570483
)
571484

572485

python/cocoindex/op.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
StructSchema,
3838
StructType,
3939
TableType,
40-
TypeAttr,
4140
encode_enriched_type_info,
4241
resolve_forward_ref,
4342
analyze_type_info,

python/cocoindex/tests/test_engine_value.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,3 +1690,35 @@ class MixedStruct:
16901690
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
16911691
mixed = MixedStruct(name="test", pydantic_order=order)
16921692
validate_full_roundtrip(mixed, MixedStruct)
1693+
1694+
1695+
def test_forward_ref_in_dataclass() -> None:
1696+
"""Test that string (forward-ref) field annotations in a dataclass are resolved."""
1697+
1698+
@dataclass
1699+
class Event:
1700+
name: "str"
1701+
tag: "Tag"
1702+
1703+
validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event)
1704+
1705+
1706+
def test_forward_ref_in_namedtuple() -> None:
1707+
"""Test that string (forward-ref) field annotations in a NamedTuple are resolved."""
1708+
1709+
class Event(NamedTuple):
1710+
name: "str"
1711+
tag: "Tag"
1712+
1713+
validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event)
1714+
1715+
1716+
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
1717+
def test_forward_ref_in_pydantic() -> None:
1718+
"""Test that string (forward-ref) field annotations in a Pydantic model are resolved."""
1719+
1720+
class Event(BaseModel):
1721+
name: "str"
1722+
tag: "Tag"
1723+
1724+
validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event)

0 commit comments

Comments
 (0)