Skip to content

Commit 9e4dc41

Browse files
committed
test: add example failing test for recursive datatype DDLs
1 parent b4fe284 commit 9e4dc41

File tree

1 file changed

+44
-1
lines changed

1 file changed

+44
-1
lines changed

duckdb_engine/tests/test_datatypes.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import decimal
2+
import enum
23
import json
34
import warnings
45
from typing import Any, Dict, Type
@@ -9,6 +10,7 @@
910
from snapshottest.module import SnapshotTest
1011
from sqlalchemy import (
1112
Column,
13+
Enum,
1214
Integer,
1315
Interval,
1416
MetaData,
@@ -28,7 +30,7 @@
2830
from sqlalchemy.types import FLOAT, JSON
2931

3032
from .._supports import duckdb_version, has_uhugeint_support
31-
from ..datatypes import Map, Struct, types
33+
from ..datatypes import Map, Struct, Union, types
3234

3335

3436
@mark.parametrize("coltype", types)
@@ -234,6 +236,47 @@ class Entry(base):
234236
assert result.outer == outer
235237

236238

239+
def test_double_nested_type_ddl(engine: Engine, session: Session) -> None:
240+
"""If we create a table with a nested type, then all the children types,
241+
(such as ENUMS) need to have already been created with an eg CREATE TYPE ...
242+
DDL statement."""
243+
importorskip("duckdb", "0.5.0") # nested types require at least duckdb 0.5.0
244+
base = declarative_base()
245+
246+
class Severity(enum.Enum):
247+
LOW = "L"
248+
MEDIUM = "M"
249+
HIGH = "H"
250+
251+
class Entry(base):
252+
__tablename__ = "test_struct"
253+
254+
id = Column(Integer, primary_key=True, default=0)
255+
struct = Column(Struct({"severity": Enum(Severity)}))
256+
map = Column(Map(String, Enum(Severity)))
257+
union = Column(Union({"age": Integer, "severity": Enum(Severity)}))
258+
259+
base.metadata.create_all(bind=engine)
260+
261+
struct = {"struct": {"severity": "L"}}
262+
session.add(Entry(struct=struct)) # type: ignore[call-arg]
263+
session.commit()
264+
result = session.query(Entry).one()
265+
assert result.struct == struct
266+
267+
map = {"one": "L", "two": "M"}
268+
session.add(Entry(map=map)) # type: ignore[call-arg]
269+
session.commit()
270+
result = session.query(Entry).one()
271+
assert result.map == map
272+
273+
union = {"age": 42}
274+
session.add(Entry(union=union)) # type: ignore[call-arg]
275+
session.commit()
276+
result = session.query(Entry).one()
277+
assert result.union == union
278+
279+
237280
def test_interval(engine: Engine, snapshot: SnapshotTest) -> None:
238281
test_table = Table("test_table", MetaData(), Column("duration", Interval))
239282

0 commit comments

Comments
 (0)