|
1 | 1 | import dataclasses |
2 | 2 | import uuid |
3 | 3 | import datetime |
4 | | -from dataclasses import dataclass |
| 4 | +from dataclasses import dataclass, make_dataclass |
5 | 5 | import pytest |
6 | 6 | from cocoindex.typing import encode_enriched_type |
7 | 7 | from cocoindex.convert import to_engine_value |
@@ -217,3 +217,65 @@ class PythonOrder: |
217 | 217 | converter = build_engine_value_converter(EngineOrder, PythonOrder) |
218 | 218 | engine_val = ["O123", "mixed nuts", 25.0] |
219 | 219 | assert converter(engine_val) == PythonOrder("O123", "mixed nuts") |
| 220 | + |
| 221 | + |
| 222 | + |
| 223 | +def make_engine_order(fields): |
| 224 | + return make_dataclass('EngineOrder', fields) |
| 225 | + |
| 226 | +def make_python_order(fields, defaults=None): |
| 227 | + if defaults is None: |
| 228 | + defaults = {} |
| 229 | + # Move all fields with defaults to the end (Python dataclass requirement) |
| 230 | + non_default_fields = [(n, t) for n, t in fields if n not in defaults] |
| 231 | + default_fields = [(n, t) for n, t in fields if n in defaults] |
| 232 | + ordered_fields = non_default_fields + default_fields |
| 233 | + # Prepare the namespace for defaults (only for fields at the end) |
| 234 | + namespace = {k: defaults[k] for k, _ in default_fields} |
| 235 | + return make_dataclass('PythonOrder', ordered_fields, namespace=namespace) |
| 236 | + |
| 237 | +@pytest.mark.parametrize( |
| 238 | + "engine_fields, python_fields, python_defaults, engine_val, expected_python_val", |
| 239 | + [ |
| 240 | + # Extra field in Python (middle) |
| 241 | + ( |
| 242 | + [("id", str), ("name", str)], |
| 243 | + [("id", str), ("price", float), ("name", str)], |
| 244 | + {"price": 0.0}, |
| 245 | + ["O123", "mixed nuts"], |
| 246 | + ("O123", 0.0, "mixed nuts"), |
| 247 | + ), |
| 248 | + # Missing field in Python (middle) |
| 249 | + ( |
| 250 | + [("id", str), ("price", float), ("name", str)], |
| 251 | + [("id", str), ("name", str)], |
| 252 | + {}, |
| 253 | + ["O123", 25.0, "mixed nuts"], |
| 254 | + ("O123", "mixed nuts"), |
| 255 | + ), |
| 256 | + # Extra field in Python (start) |
| 257 | + ( |
| 258 | + [("name", str), ("price", float)], |
| 259 | + [("extra", str), ("name", str), ("price", float)], |
| 260 | + {"extra": "default"}, |
| 261 | + ["mixed nuts", 25.0], |
| 262 | + ("default", "mixed nuts", 25.0), |
| 263 | + ), |
| 264 | + # Missing field in Python (start) |
| 265 | + ( |
| 266 | + [("extra", str), ("name", str), ("price", float)], |
| 267 | + [("name", str), ("price", float)], |
| 268 | + {}, |
| 269 | + ["unexpected", "mixed nuts", 25.0], |
| 270 | + ("mixed nuts", 25.0), |
| 271 | + ), |
| 272 | + ] |
| 273 | +) |
| 274 | +def test_field_position_cases(engine_fields, python_fields, python_defaults, engine_val, expected_python_val): |
| 275 | + EngineOrder = make_engine_order(engine_fields) |
| 276 | + PythonOrder = make_python_order(python_fields, python_defaults) |
| 277 | + converter = build_engine_value_converter(EngineOrder, PythonOrder) |
| 278 | + # Map field names to expected values |
| 279 | + expected_dict = dict(zip([f[0] for f in python_fields], expected_python_val)) |
| 280 | + # Instantiate using keyword arguments (order doesn't matter) |
| 281 | + assert converter(engine_val) == PythonOrder(**expected_dict) |
0 commit comments