Skip to content

Commit 7445dbe

Browse files
authored
fix(pkg-py): simplify DataSource inheritance/interface (#111)
* fix(pkg-py): simplify DataSource inheritance/interface * Address feedback
1 parent ce0d0d3 commit 7445dbe

File tree

1 file changed

+33
-21
lines changed

1 file changed

+33
-21
lines changed

pkg-py/src/querychat/datasource.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, ClassVar, Protocol, runtime_checkable
3+
from abc import ABC, abstractmethod
4+
from typing import TYPE_CHECKING
45

56
import duckdb
67
import narwhals.stable.v1 as nw
@@ -13,16 +14,26 @@
1314
from sqlalchemy.engine import Connection, Engine
1415

1516

16-
@runtime_checkable
17-
class DataSource(Protocol):
18-
db_engine: ClassVar[str]
17+
class DataSource(ABC):
18+
"""
19+
An abstract class defining the interface for data sources used by QueryChat.
20+
21+
Attributes
22+
----------
23+
table_name
24+
Name of the table to be used in SQL queries.
25+
26+
"""
27+
1928
table_name: str
2029

30+
@abstractmethod
2131
def get_db_type(self) -> str:
22-
"""Get the database type."""
32+
"""Name for the database behind the SQL execution."""
2333
...
2434

25-
def get_schema(self, *, categorical_threshold) -> str:
35+
@abstractmethod
36+
def get_schema(self, *, categorical_threshold: int) -> str:
2637
"""
2738
Return schema information about the table as a string.
2839
@@ -41,6 +52,7 @@ def get_schema(self, *, categorical_threshold) -> str:
4152
"""
4253
...
4354

55+
@abstractmethod
4456
def execute_query(self, query: str) -> pd.DataFrame:
4557
"""
4658
Execute SQL query and return results as DataFrame.
@@ -58,6 +70,7 @@ def execute_query(self, query: str) -> pd.DataFrame:
5870
"""
5971
...
6072

73+
@abstractmethod
6174
def get_data(self) -> pd.DataFrame:
6275
"""
6376
Return the unfiltered data as a DataFrame.
@@ -71,20 +84,9 @@ def get_data(self) -> pd.DataFrame:
7184
...
7285

7386

74-
class DataSourceBase:
75-
"""Base class for DataSource implementations."""
76-
77-
db_engine: ClassVar[str] = "standard"
78-
79-
def get_db_type(self) -> str:
80-
"""Get the database type."""
81-
return self.db_engine
82-
83-
84-
class DataFrameSource(DataSourceBase):
87+
class DataFrameSource(DataSource):
8588
"""A DataSource implementation that wraps a pandas DataFrame using DuckDB."""
8689

87-
db_engine: ClassVar[str] = "DuckDB"
8890
_df: nw.DataFrame | nw.LazyFrame
8991

9092
def __init__(self, df: IntoFrame, table_name: str):
@@ -105,6 +107,18 @@ def __init__(self, df: IntoFrame, table_name: str):
105107
# TODO(@gadenbuie): If the data frame is already SQL-backed, maybe we shouldn't be making a new copy here.
106108
self._conn.register(table_name, self._df.lazy().collect().to_pandas())
107109

110+
def get_db_type(self) -> str:
111+
"""
112+
Get the database type.
113+
114+
Returns
115+
-------
116+
:
117+
The string "DuckDB"
118+
119+
"""
120+
return "DuckDB"
121+
108122
def get_schema(self, *, categorical_threshold: int) -> str:
109123
"""
110124
Generate schema information from DataFrame.
@@ -199,7 +213,7 @@ def get_data(self) -> pd.DataFrame:
199213
return self._df.lazy().collect().to_pandas()
200214

201215

202-
class SQLAlchemySource(DataSourceBase):
216+
class SQLAlchemySource(DataSource):
203217
"""
204218
A DataSource implementation that supports multiple SQL databases via
205219
SQLAlchemy.
@@ -208,8 +222,6 @@ class SQLAlchemySource(DataSourceBase):
208222
and Databricks.
209223
"""
210224

211-
db_engine: ClassVar[str] = "SQLAlchemy"
212-
213225
def __init__(self, engine: Engine, table_name: str):
214226
"""
215227
Initialize with a SQLAlchemy engine.

0 commit comments

Comments
 (0)