Skip to content

Commit 9fb2244

Browse files
committed
optimize the process for duckdb input
1 parent aec5f85 commit 9fb2244

File tree

4 files changed

+151
-91
lines changed

4 files changed

+151
-91
lines changed

stemflow/lazyloading/open_db_connection.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,53 @@ def duckdb_config(max_mem, joblib_tmp_dir):
4141
"memory_limit": max_mem,
4242
"temp_directory": os.path.join(joblib_tmp_dir, 'duckdb'),
4343
}
44-
44+
45+
46+
47+
from contextlib import contextmanager
48+
import duckdb
49+
import pandas as pd
50+
51+
def _as_relation(con, obj, view_name, attach_alias):
52+
"""Normalize obj into a relation visible in `con` under `view_name`."""
53+
if isinstance(obj, pd.DataFrame):
54+
# keep pandas behavior same as before: return the DF
55+
return obj
56+
if isinstance(obj, str) and obj.endswith(".duckdb"):
57+
# attach the DB file under its own alias, then expose its (first) table as a view
58+
con.execute(f"ATTACH '{obj}' AS {attach_alias} (READ_ONLY)")
59+
tbl = con.sql(
60+
f"""
61+
SELECT table_name FROM {attach_alias}.information_schema.tables
62+
WHERE table_schema='main' LIMIT 1
63+
"""
64+
).fetchone()[0]
65+
rel = con.sql(f"SELECT * FROM {attach_alias}.main.{tbl}")
66+
rel.create_view(view_name)
67+
return rel
68+
if isinstance(obj, str) and obj.endswith(".parquet"):
69+
rel = con.read_parquet(obj, hive_partitioning=False)
70+
rel.create_view(view_name)
71+
return rel
72+
raise TypeError("Input must be a pandas DataFrame, .duckdb, or .parquet path.")
73+
74+
75+
@contextmanager
def open_both_Xy_db_connection(X_train, y_train, duckdb_config):
    """
    Open one DuckDB connection and expose both training inputs through it.

    `X_train` and `y_train` are each passed to `_as_relation` on the same
    connection, using view names "X_df" / "y_df" and attach aliases
    "xdb" / "ydb" respectively.

    Args:
        X_train: First input, forwarded to `_as_relation`.
        y_train: Second input, forwarded to `_as_relation`.
        duckdb_config: Passed to `duckdb.connect` as its `config` argument.

    Yields:
        A 3-tuple `(X_obj, Y_obj, con)`: whatever `_as_relation` returned for
        each input, followed by the shared connection.

    The connection is closed when the `with` block exits, whether it exits
    normally or through an exception.
    """
    # If connect() itself fails there is nothing to clean up, so the
    # connection is opened before entering the try/finally.
    con = duckdb.connect(config=duckdb_config)
    try:
        features = _as_relation(con, X_train, "X_df", "xdb")
        targets = _as_relation(con, y_train, "y_df", "ydb")
        yield features, targets, con
    finally:
        con.close()
93+

0 commit comments

Comments
 (0)