|
13 | 13 | from daft.datatype import DataType |
14 | 14 | from daft.expressions import col |
15 | 15 | from daft.io._sql import read_sql |
| 16 | +from daft.logical.schema import Field |
16 | 17 |
|
17 | 18 |
|
@contextmanager
def postgres_connection(connection_string: str, extensions: list[str] | None = None) -> psycopg.Connection.connect:
    """Context manager that provides a PostgreSQL connection with specified extensions set up.

    Args:
        connection_string: PostgreSQL connection string.
        extensions: List of extension names to create if they don't exist and are available.
            For each extension, availability is checked in pg_available_extensions before
            attempting "CREATE EXTENSION IF NOT EXISTS <extension>". Defaults to None
            (no extensions are set up).

    Yields:
        An open psycopg connection; closed automatically when the context exits.
    """
    # NOTE(review): the return annotation `psycopg.Connection.connect` is a bound method,
    # not a type; `Iterator[psycopg.Connection]` would be accurate — confirm and fix
    # alongside an `Iterator` import at the top of the file.
    with psycopg.connect(connection_string) as conn:
        if extensions:
            with conn.cursor() as cur:
                for extension in extensions:
                    # Check availability first: CREATE EXTENSION on an extension the
                    # server does not ship would raise instead of being a no-op.
                    cur.execute(
                        psycopg.sql.SQL(
                            "SELECT EXISTS(SELECT 1 FROM pg_available_extensions WHERE name = {})"
                        ).format(psycopg.sql.Literal(extension))
                    )
                    row = cur.fetchone()
                    if row and row[0]:
                        cur.execute(
                            psycopg.sql.SQL("CREATE EXTENSION IF NOT EXISTS {}").format(
                                psycopg.sql.Identifier(extension)
                            )
                        )

                # Register pgvector's Python type adapters only if the extension is
                # actually installed, so `vector` columns round-trip correctly.
                if "vector" in extensions:
                    cur.execute(
                        psycopg.sql.SQL("SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'vector')")
                    )
                    row = cur.fetchone()
                    if row and row[0]:
                        register_vector(conn)
        yield conn
37 | 58 |
|
38 | 59 |
|
@@ -187,7 +208,7 @@ def __init__(self) -> None: |
187 | 208 | raise RuntimeError("PostgresCatalog.__init__ is not supported, please use `Catalog.from_postgres` instead.") |
188 | 209 |
|
189 | 210 | @staticmethod |
190 | | - def from_uri(uri: str, extensions: list[str] | None = None, **options: str | None) -> PostgresCatalog: |
| 211 | + def from_uri(uri: str, extensions: list[str] | None, **options: str | None) -> PostgresCatalog: |
191 | 212 | """Create a PostgresCatalog from a connection string.""" |
192 | 213 | validate_connection_string(uri) |
193 | 214 | c = PostgresCatalog.__new__(PostgresCatalog) |
@@ -236,7 +257,7 @@ def _create_table( |
236 | 257 | Args: |
237 | 258 | identifier (Identifier): The identifier of the table to create. |
238 | 259 | schema (Schema): The schema of the table to create. |
239 | | - properties (Properties): The properties of the table to create. One supported property is "enable_rls" (bool), which enables Row Level Security by default. See: https://www.postgresql.org/docs/current/ddl-rowsecurity.html |
| 260 | + properties (Properties): The properties of the table to create. One supported property is "enable_rls" (bool), which enables Row Level Security. This property is set to True by default. See: https://www.postgresql.org/docs/current/ddl-rowsecurity.html |
240 | 261 | partition_fields (list[PartitionField]): The partition fields of the table to create. |
241 | 262 |
|
242 | 263 | Returns: |
@@ -280,7 +301,7 @@ def _create_table( |
280 | 301 | ) |
281 | 302 | ) |
282 | 303 |
|
283 | | - if properties and properties.get("enable_rls", False): |
| 304 | + if properties is None or properties.get("enable_rls", True): |
284 | 305 | cur.execute( |
285 | 306 | psycopg.sql.SQL("ALTER TABLE {} ENABLE ROW LEVEL SECURITY").format(quoted_full_table) |
286 | 307 | ) |
@@ -480,7 +501,113 @@ def name(self) -> str: |
480 | 501 |
|
481 | 502 | def schema(self) -> Schema: |
482 | 503 | """Returns the table's schema.""" |
483 | | - return self.read().schema() |
| 504 | + connection_string, identifier = self._inner |
| 505 | + |
| 506 | + if len(identifier) == 1: |
| 507 | + # When no schema is specified, PostgreSQL uses the schema search path to select the schema to use. |
| 508 | + # Since this is user-configurable, we simply pass along the single identifier to PostgreSQL. |
| 509 | + # See: https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH |
| 510 | + schema_name = None |
| 511 | + table_name = identifier[0] |
| 512 | + elif len(identifier) == 2: |
| 513 | + schema_name = identifier[0] |
| 514 | + table_name = identifier[1] |
| 515 | + else: |
| 516 | + raise ValueError(f"Invalid table identifier: {identifier}") |
| 517 | + |
| 518 | + # Query the database schema to get column information |
| 519 | + with postgres_connection(connection_string, self._extensions) as conn: |
| 520 | + with conn.cursor() as cur: |
| 521 | + if schema_name: |
| 522 | + cur.execute( |
| 523 | + psycopg.sql.SQL(""" |
| 524 | + SELECT |
| 525 | + c.column_name, |
| 526 | + c.data_type, |
| 527 | + c.udt_name, |
| 528 | + CASE |
| 529 | + WHEN c.data_type = 'USER-DEFINED' AND c.udt_name = 'vector' |
| 530 | + THEN a.atttypmod |
| 531 | + ELSE NULL |
| 532 | + END as vector_dimension |
| 533 | + FROM information_schema.columns c |
| 534 | + JOIN pg_class cls ON cls.relname = c.table_name |
| 535 | + JOIN pg_namespace nsp ON nsp.oid = cls.relnamespace AND nsp.nspname = c.table_schema |
| 536 | + LEFT JOIN pg_attribute a ON a.attrelid = cls.oid AND a.attname = c.column_name |
| 537 | + WHERE c.table_schema = {} AND c.table_name = {} |
| 538 | + ORDER BY c.ordinal_position |
| 539 | + """).format(psycopg.sql.Literal(schema_name), psycopg.sql.Literal(table_name)), |
| 540 | + ) |
| 541 | + else: |
| 542 | + cur.execute( |
| 543 | + psycopg.sql.SQL(""" |
| 544 | + SELECT |
| 545 | + c.column_name, |
| 546 | + c.data_type, |
| 547 | + c.udt_name, |
| 548 | + CASE |
| 549 | + WHEN c.data_type = 'USER-DEFINED' AND c.udt_name = 'vector' |
| 550 | + THEN a.atttypmod |
| 551 | + ELSE NULL |
| 552 | + END as vector_dimension |
| 553 | + FROM information_schema.columns c |
| 554 | + JOIN pg_class cls ON cls.relname = c.table_name |
| 555 | + LEFT JOIN pg_attribute a ON a.attrelid = cls.oid AND a.attname = c.column_name |
| 556 | + WHERE c.table_name = {} |
| 557 | + ORDER BY c.ordinal_position |
| 558 | + """).format(psycopg.sql.Literal(table_name)), |
| 559 | + ) |
| 560 | + |
| 561 | + columns = cur.fetchall() |
| 562 | + |
| 563 | + # If no columns found, fall back to data-based inference |
| 564 | + if not columns: |
| 565 | + return self.read().schema() |
| 566 | + |
| 567 | + # Build schema from database metadata |
| 568 | + fields = [] |
| 569 | + for column_name, data_type, udt_name, vector_dimension in columns: |
| 570 | + if data_type == "USER-DEFINED" and udt_name == "vector": |
| 571 | + # This is a pgvector column, convert to embedding type |
| 572 | + # vector_dimension from atttypmod contains the dimension information |
| 573 | + # For pgvector, atttypmod stores the dimension directly |
| 574 | + dimension = vector_dimension if vector_dimension and vector_dimension > 0 else 0 |
| 575 | + |
| 576 | + if dimension > 0: |
| 577 | + fields.append(Field.create(column_name, DataType.embedding(DataType.float32(), dimension))) |
| 578 | + else: |
| 579 | + # Fallback to list if we can't determine dimension |
| 580 | + fields.append(Field.create(column_name, DataType.list(DataType.float32()))) |
| 581 | + else: |
| 582 | + # For non-vector columns, try direct PostgreSQL type mapping first |
| 583 | + try: |
| 584 | + # Attempt to map PostgreSQL type directly to Daft type |
| 585 | + inferred_dtype = DataType.from_sql(data_type) |
| 586 | + fields.append(Field.create(column_name, inferred_dtype)) |
| 587 | + except Exception: |
| 588 | + # Fall back to data-based inference for unmappable types |
| 589 | + # This is inefficient but ensures we get the correct types |
| 590 | + if schema_name: |
| 591 | + single_col_query = psycopg.sql.SQL("SELECT {} FROM {}.{} LIMIT 1").format( |
| 592 | + psycopg.sql.Identifier(column_name), |
| 593 | + psycopg.sql.Identifier(schema_name), |
| 594 | + psycopg.sql.Identifier(table_name), |
| 595 | + ) |
| 596 | + else: |
| 597 | + single_col_query = psycopg.sql.SQL("SELECT {} FROM {} LIMIT 1").format( |
| 598 | + psycopg.sql.Identifier(column_name), |
| 599 | + psycopg.sql.Identifier(table_name), |
| 600 | + ) |
| 601 | + |
| 602 | + try: |
| 603 | + single_col_df = read_sql(single_col_query.as_string(), connection_string) |
| 604 | + inferred_dtype = single_col_df.schema()[column_name].dtype |
| 605 | + fields.append(Field.create(column_name, inferred_dtype)) |
| 606 | + except Exception: |
| 607 | + # If inference fails, use string as fallback |
| 608 | + fields.append(Field.create(column_name, DataType.string())) |
| 609 | + |
| 610 | + return Schema._from_fields(fields) |
484 | 611 |
|
485 | 612 | @staticmethod |
486 | 613 | def _from_obj(obj: object) -> PostgresTable: |
@@ -510,12 +637,20 @@ def read( |
510 | 637 |
|
511 | 638 | query = psycopg.sql.SQL("SELECT * FROM {}").format(quoted_full_table) |
512 | 639 |
|
513 | | - return read_sql( |
| 640 | + df = read_sql( |
514 | 641 | query.as_string(), |
515 | 642 | connection_string, |
516 | 643 | **options, |
517 | 644 | ) |
518 | 645 |
|
| 646 | + # Cast any vector columns that were read as lists to embeddings |
| 647 | + schema = self.schema() # Use our custom schema method |
| 648 | + for field in schema: |
| 649 | + if field.dtype.is_embedding(): |
| 650 | + df = df.with_column(field.name, df[field.name].cast(field.dtype)) |
| 651 | + |
| 652 | + return df |
| 653 | + |
519 | 654 | def append(self, df: DataFrame, **options: Any) -> None: |
520 | 655 | """Append the DataFrame to the table.""" |
521 | 656 | connection_string, identifier = self._inner |
|
0 commit comments