Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
PGTypeCompiler,
)
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.dialects.postgresql.types import OID, REGCLASS
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.interfaces import Dialect as RootDialect
from sqlalchemy.engine.reflection import cache
Expand Down Expand Up @@ -564,6 +565,159 @@ def _pg_class_filter_scope_schema(
)
return query

def _columns_query( # type: ignore[no-untyped-def]
    self, schema, has_filter_names, scope, kind
):
    """
    Override of SQLAlchemy's PostgreSQL dialect _columns_query method.

    This method is customized for DuckDB to avoid querying pg_collation,
    which doesn't exist in DuckDB's catalog. We simply return NULL for
    the collation field instead of trying to fetch it from pg_collation.

    The parameters mirror the upstream PostgreSQL dialect method:
    ``schema`` is the schema name to reflect (or None for the default),
    ``has_filter_names`` indicates that a ``filter_names`` bind parameter
    will be supplied at execution time, and ``scope``/``kind`` select
    which relations (e.g. tables vs. views, default vs. temporary) are
    included.  Returns a SQLAlchemy Select over the pg_catalog tables
    producing one row per column.
    """
    # Local import: the pg_catalog table definitions are an internal
    # part of the postgresql dialect package (hence the type: ignore).
    from sqlalchemy.dialects.postgresql import ( # type: ignore[attr-defined]
        pg_catalog,
    )

    # NOTE: the query with the default and identity options scalar
    # subquery is faster than trying to use outer joins for them

    # attgenerated only exists from PG 12 onward; emit NULL for the
    # "generated" column on older advertised server versions.
    generated = (
        pg_catalog.pg_attribute.c.attgenerated.label("generated")
        if self.server_version_info >= (12,)
        else sql.null().label("generated")
    )
    if self.server_version_info >= (10,):
        # join lateral performs worse (~2x slower) than a scalar_subquery
        # also the subquery can be run only if the column is an identity
        identity = sql.case(
            ( # attidentity != '' is required or it will reflect also
                # serial columns as identity.
                pg_catalog.pg_attribute.c.attidentity != "",
                select(
                    # Package all sequence options for the identity
                    # column into a single JSON object so only one
                    # scalar subquery is needed per row.
                    sql.func.json_build_object(
                        "always",
                        pg_catalog.pg_attribute.c.attidentity == "a",
                        "start",
                        pg_catalog.pg_sequence.c.seqstart,
                        "increment",
                        pg_catalog.pg_sequence.c.seqincrement,
                        "minvalue",
                        pg_catalog.pg_sequence.c.seqmin,
                        "maxvalue",
                        pg_catalog.pg_sequence.c.seqmax,
                        "cache",
                        pg_catalog.pg_sequence.c.seqcache,
                        "cycle",
                        pg_catalog.pg_sequence.c.seqcycle,
                        type_=sqltypes.JSON(),
                    )
                )
                .select_from(pg_catalog.pg_sequence)
                .where(
                    # not needed but pg seems to like it
                    pg_catalog.pg_attribute.c.attidentity != "",
                    # Resolve the backing sequence's OID by round-tripping
                    # through pg_get_serial_sequence: attrelid -> regclass
                    # -> text table name -> sequence name -> regclass -> oid.
                    pg_catalog.pg_sequence.c.seqrelid
                    == sql.cast(
                        sql.cast(
                            pg_catalog.pg_get_serial_sequence(
                                sql.cast(
                                    sql.cast(
                                        pg_catalog.pg_attribute.c.attrelid,
                                        REGCLASS,
                                    ),
                                    sqltypes.TEXT,
                                ),
                                pg_catalog.pg_attribute.c.attname,
                            ),
                            REGCLASS,
                        ),
                        OID,
                    ),
                )
                .correlate(pg_catalog.pg_attribute)
                .scalar_subquery(),
            ),
            else_=sql.null(),
        ).label("identity_options")
    else:
        # Identity columns require PG 10+; report NULL otherwise.
        identity = sql.null().label("identity_options")

    # join lateral performs the same as scalar_subquery here, also
    # the subquery can be run only if the column has a default
    default = sql.case(
        (
            pg_catalog.pg_attribute.c.atthasdef,
            select(
                # Render the stored default expression tree back to text.
                pg_catalog.pg_get_expr(
                    pg_catalog.pg_attrdef.c.adbin,
                    pg_catalog.pg_attrdef.c.adrelid,
                )
            )
            .select_from(pg_catalog.pg_attrdef)
            .where(
                # not needed but pg seems to like it
                pg_catalog.pg_attribute.c.atthasdef,
                pg_catalog.pg_attrdef.c.adrelid
                == pg_catalog.pg_attribute.c.attrelid,
                pg_catalog.pg_attrdef.c.adnum == pg_catalog.pg_attribute.c.attnum,
            )
            .correlate(pg_catalog.pg_attribute)
            .scalar_subquery(),
        ),
        else_=sql.null(),
    ).label("default")

    # DuckDB doesn't have pg_collation table, so we return NULL instead
    # of trying to query it like the PostgreSQL dialect does
    collate = sql.null().label("collation")

    # Translate the reflection "kind" (tables/views/...) into pg_class
    # relkind codes via the inherited dialect helper.
    relkinds = self._kind_to_relkinds(kind) # type: ignore[attr-defined]
    query = (
        select(
            pg_catalog.pg_attribute.c.attname.label("name"),
            # format_type renders the column's type (including typmod,
            # e.g. VARCHAR(10)) as its SQL text representation.
            pg_catalog.format_type(
                pg_catalog.pg_attribute.c.atttypid,
                pg_catalog.pg_attribute.c.atttypmod,
            ).label("format_type"),
            default,
            pg_catalog.pg_attribute.c.attnotnull.label("not_null"),
            pg_catalog.pg_class.c.relname.label("table_name"),
            pg_catalog.pg_description.c.description.label("comment"),
            generated,
            identity,
            collate,
        )
        .select_from(pg_catalog.pg_class)
        # NOTE: postgresql support table with no user column, meaning
        # there is no row with pg_attribute.attnum > 0. use a left outer
        # join to avoid filtering these tables.
        .outerjoin(
            pg_catalog.pg_attribute,
            sql.and_(
                pg_catalog.pg_class.c.oid == pg_catalog.pg_attribute.c.attrelid,
                # attnum > 0 skips system columns; attisdropped skips
                # columns that have been dropped but not yet vacuumed.
                pg_catalog.pg_attribute.c.attnum > 0,
                ~pg_catalog.pg_attribute.c.attisdropped,
            ),
        )
        # Column comments live in pg_description keyed by (objoid,
        # objsubid) = (table oid, column number); outer join so columns
        # without comments still appear.
        .outerjoin(
            pg_catalog.pg_description,
            sql.and_(
                pg_catalog.pg_description.c.objoid
                == pg_catalog.pg_attribute.c.attrelid,
                pg_catalog.pg_description.c.objsubid
                == pg_catalog.pg_attribute.c.attnum,
            ),
        )
        .where(self._pg_class_relkind_condition(relkinds)) # type: ignore[attr-defined]
        # Deterministic ordering: by table, then column position.
        .order_by(pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum)
    )
    # Apply schema / scope (default vs temporary) restrictions via the
    # inherited dialect helper.
    query = self._pg_class_filter_scope_schema(query, schema, scope=scope) # type: ignore[attr-defined]
    if has_filter_names:
        # The actual name list is bound at execution time under the
        # "filter_names" parameter by the caller.
        query = query.where(
            pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
        )
    return query

# FIXME: this method is a hack around the fact that we use a single cursor for all queries inside a connection,
# and this is required to fix get_multi_columns
def get_multi_columns(
Expand Down