|
| 1 | +import json |
| 2 | + |
| 3 | +import pandas as pd |
| 4 | +import duckdb |
| 5 | + |
| 6 | +from data_formulator.data_loader.external_data_loader import ExternalDataLoader, sanitize_table_name |
| 7 | +from typing import Dict, Any, List |
| 8 | + |
class PostgreSQLDataLoader(ExternalDataLoader):
    """Data loader that reads PostgreSQL through DuckDB's ``postgres`` extension.

    On construction the PostgreSQL database described by ``params`` is attached
    to the supplied DuckDB connection under the alias ``mypostgresdb``. All
    other methods query that attached catalog through the same connection.
    """

    @staticmethod
    def list_params() -> List[Dict[str, Any]]:
        """Return the descriptors of the connection parameters this loader reads."""
        params_list = [
            {"name": "user", "type": "string", "required": True, "default": "postgres", "description": "PostgreSQL username"},
            {"name": "password", "type": "string", "required": False, "default": "", "description": "leave blank for no password"},
            {"name": "host", "type": "string", "required": True, "default": "localhost", "description": "PostgreSQL host"},
            {"name": "port", "type": "string", "required": False, "default": "5432", "description": "PostgreSQL port"},
            {"name": "database", "type": "string", "required": True, "default": "postgres", "description": "PostgreSQL database name"}
        ]
        return params_list

    @staticmethod
    def auth_instructions() -> str:
        """Return the help text describing what the user must provide."""
        return "Provide your PostgreSQL connection details. The user must have SELECT permissions on the tables you want to access."

    @staticmethod
    def _quote_conn_value(value: Any) -> str:
        """Render ``value`` for a libpq ``keyword=value`` connection string.

        Plain values are returned unchanged. Empty values and values containing
        whitespace, single quotes or backslashes are wrapped in single quotes
        with ``\\`` and ``'`` backslash-escaped, as libpq requires.
        """
        text = str(value)
        needs_quoting = (
            text == ""
            or "'" in text
            or "\\" in text
            or any(ch.isspace() for ch in text)
        )
        if not needs_quoting:
            return text
        escaped = text.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    @staticmethod
    def _quote_identifier(identifier: Any) -> str:
        """Return ``identifier`` as a double-quoted SQL identifier."""
        escaped = str(identifier).replace('"', '""')
        return f'"{escaped}"'

    def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnection):
        """Attach the PostgreSQL database to ``duck_db_conn`` as ``mypostgresdb``.

        Args:
            params: Connection settings. ``host``, ``user`` and ``database`` are
                read with item access (a missing key raises ``KeyError``);
                ``port`` and ``password`` are optional.
            duck_db_conn: DuckDB connection that receives the attachment.

        Raises:
            Exception: Any error raised while installing/loading the extension,
                building the connection string or attaching is printed and
                re-raised unchanged.
        """
        self.params = params
        self.duck_db_conn = duck_db_conn

        try:
            # Install and load the Postgres extension
            self.duck_db_conn.install_extension("postgres")
            self.duck_db_conn.load_extension("postgres")

            # Prepare the connection string for Postgres. `or` (rather than a
            # .get default) also covers a port that is present but blank.
            port = self.params.get('port') or '5432'
            conn_fields = [
                ("host", self.params['host']),
                ("port", port),
                ("user", self.params['user']),
            ]
            if self.params.get('password'):
                conn_fields.append(("password", self.params['password']))
            conn_fields.append(("dbname", self.params['database']))
            attach_string = " ".join(
                f"{key}={self._quote_conn_value(value)}" for key, value in conn_fields
            )
            # The connection string is embedded in a SQL string literal, so any
            # single quote inside it must be doubled.
            attach_literal = attach_string.replace("'", "''")

            # Detach existing postgres connection if it exists
            try:
                self.duck_db_conn.execute("DETACH mypostgresdb;")
            except Exception:
                # Best effort: nothing was attached under this alias yet.
                # (Not a bare except, so Ctrl-C / SystemExit still propagate.)
                pass

            # Register Postgres connection
            self.duck_db_conn.execute(f"ATTACH '{attach_literal}' AS mypostgresdb (TYPE postgres);")
            print(f"Successfully connected to PostgreSQL database: {self.params['database']}")

        except Exception as e:
            print(f"Failed to connect to PostgreSQL: {e}")
            raise

    def list_tables(self):
        """List the base tables of the attached database with basic metadata.

        Returns:
            A list of ``{"name": ..., "metadata": ...}`` dicts, where ``name``
            is ``mypostgresdb.<schema>.<table>`` and ``metadata`` holds
            ``row_count``, ``columns`` (name/type pairs) and ``sample_rows``
            (up to 10 rows). Tables whose inspection fails are reported via
            ``print`` and skipped; if the listing query itself fails an empty
            list is returned.
        """
        try:
            # Query tables through DuckDB's attached PostgreSQL connection
            tables_df = self.duck_db_conn.execute("""
                SELECT table_schema as schemaname, table_name as tablename
                FROM mypostgresdb.information_schema.tables
                WHERE table_schema NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
                AND table_schema NOT LIKE '%_intern%'
                AND table_schema NOT LIKE '%timescaledb%'
                AND table_name NOT LIKE '%/%'
                AND table_type = 'BASE TABLE'
                ORDER BY table_schema, table_name
            """).fetch_df()

            print(f"Found tables: {tables_df}")

            results = []

            for schema, table_name in tables_df.values:
                # Name reported to callers (unchanged, unquoted form).
                full_table_name = f"mypostgresdb.{schema}.{table_name}"
                # Quoted form used in SQL so names containing spaces,
                # punctuation or reserved words still resolve.
                quoted_table = ".".join(
                    self._quote_identifier(part)
                    for part in ("mypostgresdb", schema, table_name)
                )

                try:
                    # Get column information using DuckDB's DESCRIBE
                    columns_df = self.duck_db_conn.execute(f"DESCRIBE {quoted_table}").df()
                    columns = [
                        {'name': col_name, 'type': col_type}
                        for col_name, col_type in zip(
                            columns_df['column_name'], columns_df['column_type']
                        )
                    ]

                    # Get sample data
                    sample_df = self.duck_db_conn.execute(f"SELECT * FROM {quoted_table} LIMIT 10").df()
                    sample_rows = json.loads(sample_df.to_json(orient="records"))

                    # Get row count
                    row_count = self.duck_db_conn.execute(f"SELECT COUNT(*) FROM {quoted_table}").fetchone()[0]

                    table_metadata = {
                        "row_count": row_count,
                        "columns": columns,
                        "sample_rows": sample_rows
                    }

                    results.append({
                        "name": full_table_name,
                        "metadata": table_metadata
                    })

                except Exception as e:
                    print(f"Error processing table {full_table_name}: {e}")
                    continue

            return results

        except Exception as e:
            print(f"Error listing tables: {e}")
            return []

    def ingest_data(self, table_name: str, name_as: str | None = None, size: int = 1000000):
        """Copy up to ``size`` rows of ``table_name`` into a table in DuckDB's ``main`` schema.

        Args:
            table_name: Source table reference, interpolated as-is into the
                query (e.g. a ``name`` value returned by ``list_tables``).
            name_as: Target table name; defaults to the last dot-separated
                component of ``table_name``. It is passed through
                ``sanitize_table_name`` before use.
            size: Maximum number of rows to copy.
        """
        # Create table in the main DuckDB database from Postgres data
        if name_as is None:
            name_as = table_name.split('.')[-1]

        name_as = sanitize_table_name(name_as)

        # Coerce to int so only a number can ever reach the LIMIT clause.
        row_limit = int(size)

        self.duck_db_conn.execute(f"""
            CREATE OR REPLACE TABLE main.{name_as} AS
            SELECT * FROM {table_name}
            LIMIT {row_limit}
        """)

    def view_query_sample(self, query: str) -> List[Dict[str, Any]]:
        """Run ``query`` and return its first 10 rows as JSON-compatible records."""
        preview_df = self.duck_db_conn.execute(query).df().head(10)
        return json.loads(preview_df.to_json(orient="records"))

    def ingest_data_from_query(self, query: str, name_as: str) -> pd.DataFrame:
        """Run ``query``, ingest the result under ``name_as`` and return the DataFrame."""
        # Execute the query and get results as a DataFrame
        df = self.duck_db_conn.execute(query).df()
        # Use the base class's method to ingest the DataFrame
        self.ingest_df_to_duckdb(df, name_as)
        return df
0 commit comments