Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions pkg-py/src/querychat/datasource.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, Protocol, runtime_checkable
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import duckdb
import narwhals.stable.v1 as nw
Expand All @@ -13,16 +14,26 @@
from sqlalchemy.engine import Connection, Engine


@runtime_checkable
class DataSource(Protocol):
db_engine: ClassVar[str]
class DataSource(ABC):
"""
An abstract class defining the interface for data sources used by QueryChat.

Attributes
----------
table_name
Name of the table to be used in SQL queries.

"""

table_name: str

@abstractmethod
def get_db_type(self) -> str:
"""Get the database type."""
"""Name for the database behind the SQL execution."""
...

def get_schema(self, *, categorical_threshold) -> str:
@abstractmethod
def get_schema(self, *, categorical_threshold: int) -> str:
"""
Return schema information about the table as a string.

Expand All @@ -41,6 +52,7 @@ def get_schema(self, *, categorical_threshold) -> str:
"""
...

@abstractmethod
def execute_query(self, query: str) -> pd.DataFrame:
"""
Execute SQL query and return results as DataFrame.
Expand All @@ -58,6 +70,7 @@ def execute_query(self, query: str) -> pd.DataFrame:
"""
...

@abstractmethod
def get_data(self) -> pd.DataFrame:
"""
Return the unfiltered data as a DataFrame.
Expand All @@ -71,20 +84,9 @@ def get_data(self) -> pd.DataFrame:
...


class DataSourceBase:
"""Base class for DataSource implementations."""

db_engine: ClassVar[str] = "standard"

def get_db_type(self) -> str:
"""Get the database type."""
return self.db_engine


class DataFrameSource(DataSourceBase):
class DataFrameSource(DataSource):
"""A DataSource implementation that wraps a pandas DataFrame using DuckDB."""

db_engine: ClassVar[str] = "DuckDB"
_df: nw.DataFrame | nw.LazyFrame

def __init__(self, df: IntoFrame, table_name: str):
Expand All @@ -105,6 +107,18 @@ def __init__(self, df: IntoFrame, table_name: str):
# TODO(@gadenbuie): If the data frame is already SQL-backed, maybe we shouldn't be making a new copy here.
self._conn.register(table_name, self._df.lazy().collect().to_pandas())

def get_db_type(self) -> str:
"""
Get the database type.

Returns
-------
:
The string "DuckDB"

"""
return "DuckDB"

def get_schema(self, *, categorical_threshold: int) -> str:
"""
Generate schema information from DataFrame.
Expand Down Expand Up @@ -199,7 +213,7 @@ def get_data(self) -> pd.DataFrame:
return self._df.lazy().collect().to_pandas()


class SQLAlchemySource(DataSourceBase):
class SQLAlchemySource(DataSource):
"""
A DataSource implementation that supports multiple SQL databases via
SQLAlchemy.
Expand All @@ -208,8 +222,6 @@ class SQLAlchemySource(DataSourceBase):
and Databricks.
"""

db_engine: ClassVar[str] = "SQLAlchemy"

def __init__(self, engine: Engine, table_name: str):
"""
Initialize with a SQLAlchemy engine.
Expand Down