Skip to content

Commit 5d45115

Browse files
feat: PostgresCatalog and PostgresTable followups (#5508)
## Changes Made From user feedback: - Enable "vector" extension by default, if available - Make Row Level Security opt-out, following Supabase behaviour - Add docs - Cast `vector` to `embedding` correctly on `read_sql()`
1 parent 670eecc commit 5d45115

File tree

6 files changed

+418
-28
lines changed

6 files changed

+418
-28
lines changed

daft/catalog/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,21 +313,21 @@ def from_glue(
313313
raise ImportError("AWS Glue support not installed: pip install -U 'daft[aws]'")
314314

315315
@staticmethod
316-
def from_postgres(connection_string: str, extensions: list[str] | None = None) -> Catalog:
316+
def from_postgres(connection_string: str, extensions: list[str] | None = ["vector"]) -> Catalog:
317317
"""Create a Daft Catalog from a PostgreSQL connection string.
318318
319-
Note::
320-
This is an experimental feature and the API may change in the future.
321-
322319
Args:
323320
connection_string (str): a PostgreSQL connection string
324321
extensions (list[str], optional): List of PostgreSQL extensions to create if they don't exist.
325322
For each extension, "CREATE EXTENSION IF NOT EXISTS <extension>" will be executed.
326-
Defaults to None (no extensions).
323+
Defaults to ["vector"] (pgvector extension, if available).
327324
328325
Returns:
329326
Catalog: a new Catalog instance to a PostgreSQL database.
330327
328+
Warning:
329+
This feature is early in development and will likely experience API changes.
330+
331331
Examples:
332332
>>> catalog = Catalog.from_postgres("postgresql://user:password@host:port/database")
333333
>>> catalog = Catalog.from_postgres(

daft/catalog/__postgres.py

Lines changed: 150 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,47 @@
1313
from daft.datatype import DataType
1414
from daft.expressions import col
1515
from daft.io._sql import read_sql
16+
from daft.logical.schema import Field
1617

1718

1819
@contextmanager
19-
def postgres_connection(connection_string: str, extensions: list[str] | None = None) -> psycopg.Connection.connect:
20+
def postgres_connection(connection_string: str, extensions: list[str] | None) -> psycopg.Connection.connect:
2021
"""Context manager that provides a PostgreSQL connection with specified extensions setup.
2122
2223
Args:
2324
connection_string: PostgreSQL connection string
24-
extensions: List of extension names to create if they don't exist. For each extension,
25-
"CREATE EXTENSION IF NOT EXISTS <extension>" will be executed.
25+
extensions: List of extension names to create if they don't exist and are available.
26+
For each extension, availability is checked in pg_available_extensions before
27+
attempting "CREATE EXTENSION IF NOT EXISTS <extension>".
2628
"""
2729
with psycopg.connect(connection_string) as conn:
2830
if extensions:
29-
for extension in extensions:
30-
conn.execute(
31-
psycopg.sql.SQL("CREATE EXTENSION IF NOT EXISTS {}").format(psycopg.sql.Identifier(extension))
32-
)
33-
# Special handling for vector extension - register pgvector types.
34-
if "vector" in extensions:
35-
register_vector(conn)
31+
with conn.cursor() as cur:
32+
for extension in extensions:
33+
# Check if extension is available before attempting to create it
34+
cur.execute(
35+
psycopg.sql.SQL("SELECT EXISTS(SELECT 1 FROM pg_available_extensions WHERE name = {})").format(
36+
psycopg.sql.Literal(extension)
37+
)
38+
)
39+
result = cur.fetchone()
40+
is_available = result[0] if result else False
41+
42+
if is_available:
43+
cur.execute(
44+
psycopg.sql.SQL("CREATE EXTENSION IF NOT EXISTS {}").format(
45+
psycopg.sql.Identifier(extension)
46+
)
47+
)
48+
49+
# Register pgvector type if it was successfully created
50+
if "vector" in extensions:
51+
cur.execute(psycopg.sql.SQL("SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'vector')"))
52+
result = cur.fetchone()
53+
vector_installed = result[0] if result else False
54+
55+
if vector_installed:
56+
register_vector(conn)
3657
yield conn
3758

3859

@@ -187,7 +208,7 @@ def __init__(self) -> None:
187208
raise RuntimeError("PostgresCatalog.__init__ is not supported, please use `Catalog.from_postgres` instead.")
188209

189210
@staticmethod
190-
def from_uri(uri: str, extensions: list[str] | None = None, **options: str | None) -> PostgresCatalog:
211+
def from_uri(uri: str, extensions: list[str] | None, **options: str | None) -> PostgresCatalog:
191212
"""Create a PostgresCatalog from a connection string."""
192213
validate_connection_string(uri)
193214
c = PostgresCatalog.__new__(PostgresCatalog)
@@ -236,7 +257,7 @@ def _create_table(
236257
Args:
237258
identifier (Identifier): The identifier of the table to create.
238259
schema (Schema): The schema of the table to create.
239-
properties (Properties): The properties of the table to create. One supported property is "enable_rls" (bool), which enables Row Level Security by default. See: https://www.postgresql.org/docs/current/ddl-rowsecurity.html
260+
properties (Properties): The properties of the table to create. One supported property is "enable_rls" (bool), which enables Row Level Security. This property is set to True by default. See: https://www.postgresql.org/docs/current/ddl-rowsecurity.html
240261
partition_fields (list[PartitionField]): The partition fields of the table to create.
241262
242263
Returns:
@@ -280,7 +301,7 @@ def _create_table(
280301
)
281302
)
282303

283-
if properties and properties.get("enable_rls", False):
304+
if properties is None or properties.get("enable_rls", True):
284305
cur.execute(
285306
psycopg.sql.SQL("ALTER TABLE {} ENABLE ROW LEVEL SECURITY").format(quoted_full_table)
286307
)
@@ -480,7 +501,113 @@ def name(self) -> str:
480501

481502
def schema(self) -> Schema:
482503
"""Returns the table's schema."""
483-
return self.read().schema()
504+
connection_string, identifier = self._inner
505+
506+
if len(identifier) == 1:
507+
# When no schema is specified, PostgreSQL uses the schema search path to select the schema to use.
508+
# Since this is user-configurable, we simply pass along the single identifier to PostgreSQL.
509+
# See: https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH
510+
schema_name = None
511+
table_name = identifier[0]
512+
elif len(identifier) == 2:
513+
schema_name = identifier[0]
514+
table_name = identifier[1]
515+
else:
516+
raise ValueError(f"Invalid table identifier: {identifier}")
517+
518+
# Query the database schema to get column information
519+
with postgres_connection(connection_string, self._extensions) as conn:
520+
with conn.cursor() as cur:
521+
if schema_name:
522+
cur.execute(
523+
psycopg.sql.SQL("""
524+
SELECT
525+
c.column_name,
526+
c.data_type,
527+
c.udt_name,
528+
CASE
529+
WHEN c.data_type = 'USER-DEFINED' AND c.udt_name = 'vector'
530+
THEN a.atttypmod
531+
ELSE NULL
532+
END as vector_dimension
533+
FROM information_schema.columns c
534+
JOIN pg_class cls ON cls.relname = c.table_name
535+
JOIN pg_namespace nsp ON nsp.oid = cls.relnamespace AND nsp.nspname = c.table_schema
536+
LEFT JOIN pg_attribute a ON a.attrelid = cls.oid AND a.attname = c.column_name
537+
WHERE c.table_schema = {} AND c.table_name = {}
538+
ORDER BY c.ordinal_position
539+
""").format(psycopg.sql.Literal(schema_name), psycopg.sql.Literal(table_name)),
540+
)
541+
else:
542+
cur.execute(
543+
psycopg.sql.SQL("""
544+
SELECT
545+
c.column_name,
546+
c.data_type,
547+
c.udt_name,
548+
CASE
549+
WHEN c.data_type = 'USER-DEFINED' AND c.udt_name = 'vector'
550+
THEN a.atttypmod
551+
ELSE NULL
552+
END as vector_dimension
553+
FROM information_schema.columns c
554+
JOIN pg_class cls ON cls.relname = c.table_name
555+
LEFT JOIN pg_attribute a ON a.attrelid = cls.oid AND a.attname = c.column_name
556+
WHERE c.table_name = {}
557+
ORDER BY c.ordinal_position
558+
""").format(psycopg.sql.Literal(table_name)),
559+
)
560+
561+
columns = cur.fetchall()
562+
563+
# If no columns found, fall back to data-based inference
564+
if not columns:
565+
return self.read().schema()
566+
567+
# Build schema from database metadata
568+
fields = []
569+
for column_name, data_type, udt_name, vector_dimension in columns:
570+
if data_type == "USER-DEFINED" and udt_name == "vector":
571+
# This is a pgvector column, convert to embedding type
572+
# vector_dimension from atttypmod contains the dimension information
573+
# For pgvector, atttypmod stores the dimension directly
574+
dimension = vector_dimension if vector_dimension and vector_dimension > 0 else 0
575+
576+
if dimension > 0:
577+
fields.append(Field.create(column_name, DataType.embedding(DataType.float32(), dimension)))
578+
else:
579+
# Fallback to list if we can't determine dimension
580+
fields.append(Field.create(column_name, DataType.list(DataType.float32())))
581+
else:
582+
# For non-vector columns, try direct PostgreSQL type mapping first
583+
try:
584+
# Attempt to map PostgreSQL type directly to Daft type
585+
inferred_dtype = DataType.from_sql(data_type)
586+
fields.append(Field.create(column_name, inferred_dtype))
587+
except Exception:
588+
# Fall back to data-based inference for unmappable types
589+
# This is inefficient but ensures we get the correct types
590+
if schema_name:
591+
single_col_query = psycopg.sql.SQL("SELECT {} FROM {}.{} LIMIT 1").format(
592+
psycopg.sql.Identifier(column_name),
593+
psycopg.sql.Identifier(schema_name),
594+
psycopg.sql.Identifier(table_name),
595+
)
596+
else:
597+
single_col_query = psycopg.sql.SQL("SELECT {} FROM {} LIMIT 1").format(
598+
psycopg.sql.Identifier(column_name),
599+
psycopg.sql.Identifier(table_name),
600+
)
601+
602+
try:
603+
single_col_df = read_sql(single_col_query.as_string(), connection_string)
604+
inferred_dtype = single_col_df.schema()[column_name].dtype
605+
fields.append(Field.create(column_name, inferred_dtype))
606+
except Exception:
607+
# If inference fails, use string as fallback
608+
fields.append(Field.create(column_name, DataType.string()))
609+
610+
return Schema._from_fields(fields)
484611

485612
@staticmethod
486613
def _from_obj(obj: object) -> PostgresTable:
@@ -510,12 +637,20 @@ def read(
510637

511638
query = psycopg.sql.SQL("SELECT * FROM {}").format(quoted_full_table)
512639

513-
return read_sql(
640+
df = read_sql(
514641
query.as_string(),
515642
connection_string,
516643
**options,
517644
)
518645

646+
# Cast any vector columns that were read as lists to embeddings
647+
schema = self.schema() # Use our custom schema method
648+
for field in schema:
649+
if field.dtype.is_embedding():
650+
df = df.with_column(field.name, df[field.name].cast(field.dtype))
651+
652+
return df
653+
519654
def append(self, df: DataFrame, **options: Any) -> None:
520655
"""Append the DataFrame to the table."""
521656
connection_string, identifier = self._inner

docs/SUMMARY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
* [Delta Lake](connectors/delta_lake.md)
3636
* [Lance](connectors/lance.md)
3737
* [Hugging Face Datasets](connectors/huggingface.md)
38+
* [Postgres](connectors/postgres.md)
3839
* [S3](connectors/aws.md)
3940
* [SQL Databases](connectors/sql.md)
4041
* [Turbopuffer](connectors/turbopuffer.md)

docs/connectors/index.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ See also [Lance](lance.md) for detailed integration.
7373
| [`write_parquet`][daft.dataframe.DataFrame.write_parquet] | Write a DataFrame to Parquet files |
7474

7575

76+
## PostgreSQL
77+
78+
| Function | Description |
79+
|---------------------------------------------------------|---------------------------------------------------|
80+
| [`Catalog.from_postgres`][daft.catalog.Catalog.from_postgres] | Create a catalog from a PostgreSQL database |
81+
82+
See also [PostgreSQL](postgres.md) for detailed integration.
83+
7684
## SQL
7785

7886
| Function | Description |

0 commit comments

Comments
 (0)