|
1 | 1 | import decimal |
| 2 | +import enum |
2 | 3 | import json |
3 | 4 | import warnings |
4 | 5 | from typing import Any, Dict, Type |
|
9 | 10 | from snapshottest.module import SnapshotTest |
10 | 11 | from sqlalchemy import ( |
11 | 12 | Column, |
| 13 | + Enum, |
12 | 14 | Integer, |
13 | 15 | Interval, |
14 | 16 | MetaData, |
|
28 | 30 | from sqlalchemy.types import FLOAT, JSON |
29 | 31 |
|
30 | 32 | from .._supports import duckdb_version, has_uhugeint_support |
31 | | -from ..datatypes import Map, Struct, types |
| 33 | +from ..datatypes import Map, Struct, Union, types |
32 | 34 |
|
33 | 35 |
|
34 | 36 | @mark.parametrize("coltype", types) |
@@ -234,6 +236,47 @@ class Entry(base): |
234 | 236 | assert result.outer == outer |
235 | 237 |
|
236 | 238 |
|
| 239 | +def test_double_nested_type_ddl(engine: Engine, session: Session) -> None: |
| 240 | + """If we create a table with a nested type, then all the children types, |
| 241 | + (such as ENUMS) need to have already been created with an eg CREATE TYPE ... |
| 242 | + DDL statement.""" |
| 243 | + importorskip("duckdb", "0.5.0") # nested types require at least duckdb 0.5.0 |
| 244 | + base = declarative_base() |
| 245 | + |
| 246 | + class Severity(enum.Enum): |
| 247 | + LOW = "L" |
| 248 | + MEDIUM = "M" |
| 249 | + HIGH = "H" |
| 250 | + |
| 251 | + class Entry(base): |
| 252 | + __tablename__ = "test_struct" |
| 253 | + |
| 254 | + id = Column(Integer, primary_key=True, default=0) |
| 255 | + struct = Column(Struct({"severity": Enum(Severity)})) |
| 256 | + map = Column(Map(String, Enum(Severity))) |
| 257 | + union = Column(Union({"age": Integer, "severity": Enum(Severity)})) |
| 258 | + |
| 259 | + base.metadata.create_all(bind=engine) |
| 260 | + |
| 261 | + struct = {"struct": {"severity": "L"}} |
| 262 | + session.add(Entry(struct=struct)) # type: ignore[call-arg] |
| 263 | + session.commit() |
| 264 | + result = session.query(Entry).one() |
| 265 | + assert result.struct == struct |
| 266 | + |
| 267 | + map = {"one": "L", "two": "M"} |
| 268 | + session.add(Entry(map=map)) # type: ignore[call-arg] |
| 269 | + session.commit() |
| 270 | + result = session.query(Entry).one() |
| 271 | + assert result.map == map |
| 272 | + |
| 273 | + union = {"age": 42} |
| 274 | + session.add(Entry(union=union)) # type: ignore[call-arg] |
| 275 | + session.commit() |
| 276 | + result = session.query(Entry).one() |
| 277 | + assert result.union == union |
| 278 | + |
| 279 | + |
237 | 280 | def test_interval(engine: Engine, snapshot: SnapshotTest) -> None: |
238 | 281 | test_table = Table("test_table", MetaData(), Column("duration", Interval)) |
239 | 282 |
|
|
0 commit comments