diff --git a/pkg-py/src/querychat/datasource.py b/pkg-py/src/querychat/datasource.py index 8ece671d..0cdb7627 100644 --- a/pkg-py/src/querychat/datasource.py +++ b/pkg-py/src/querychat/datasource.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar, Protocol, runtime_checkable +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import duckdb import narwhals.stable.v1 as nw @@ -13,16 +14,26 @@ from sqlalchemy.engine import Connection, Engine -@runtime_checkable -class DataSource(Protocol): - db_engine: ClassVar[str] +class DataSource(ABC): + """ + An abstract class defining the interface for data sources used by QueryChat. + + Attributes + ---------- + table_name + Name of the table to be used in SQL queries. + + """ + table_name: str + @abstractmethod def get_db_type(self) -> str: - """Get the database type.""" + """Name for the database behind the SQL execution.""" ... - def get_schema(self, *, categorical_threshold) -> str: + @abstractmethod + def get_schema(self, *, categorical_threshold: int) -> str: """ Return schema information about the table as a string. @@ -41,6 +52,7 @@ def get_schema(self, *, categorical_threshold) -> str: """ ... + @abstractmethod def execute_query(self, query: str) -> pd.DataFrame: """ Execute SQL query and return results as DataFrame. @@ -58,6 +70,7 @@ def execute_query(self, query: str) -> pd.DataFrame: """ ... + @abstractmethod def get_data(self) -> pd.DataFrame: """ Return the unfiltered data as a DataFrame. @@ -71,20 +84,9 @@ def get_data(self) -> pd.DataFrame: ... -class DataSourceBase: - """Base class for DataSource implementations.""" - - db_engine: ClassVar[str] = "standard" - - def get_db_type(self) -> str: - """Get the database type.""" - return self.db_engine - - -class DataFrameSource(DataSourceBase): +class DataFrameSource(DataSource): """A DataSource implementation that wraps a pandas DataFrame using DuckDB.""" - db_engine: ClassVar[str] = "DuckDB" _df: nw.DataFrame | nw.LazyFrame def __init__(self, df: IntoFrame, table_name: str): @@ -105,6 +107,18 @@ def __init__(self, df: IntoFrame, table_name: str): # TODO(@gadenbuie): If the data frame is already SQL-backed, maybe we shouldn't be making a new copy here. self._conn.register(table_name, self._df.lazy().collect().to_pandas()) + def get_db_type(self) -> str: + """ + Get the database type. + + Returns + ------- + : + The string "DuckDB" + + """ + return "DuckDB" + def get_schema(self, *, categorical_threshold: int) -> str: """ Generate schema information from DataFrame. @@ -199,7 +213,7 @@ def get_data(self) -> pd.DataFrame: return self._df.lazy().collect().to_pandas() -class SQLAlchemySource(DataSourceBase): +class SQLAlchemySource(DataSource): """ A DataSource implementation that supports multiple SQL databases via SQLAlchemy. @@ -208,8 +222,6 @@ class SQLAlchemySource(DataSourceBase): and Databricks. """ - db_engine: ClassVar[str] = "SQLAlchemy" - def __init__(self, engine: Engine, table_name: str): """ Initialize with a SQLAlchemy engine.