SessionContext: automatically register Python (Arrow/Pandas/Polars) objects referenced in SQL #1247
Changes from 16 commits
66d74a3
65e4492
53a62f7
1f36102
92dde5b
db2d239
8fc3e1c
b733408
fb3dadb
6454b8c
dc1b392
904c1ca
1764a57
b9041ba
ac1d6e1
15b5cec
dc06874
78c26cc
57d6380
1a1a5b4
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -228,6 +228,33 @@ Core Classes | |
* :py:meth:`~datafusion.SessionContext.from_pandas` - Create from Pandas DataFrame | ||
* :py:meth:`~datafusion.SessionContext.from_arrow` - Create from Arrow data | ||
|
||
``SessionContext`` can automatically resolve SQL table names that match | ||
in-scope Python data objects. When automatic lookup is enabled, a query | ||
such as ``ctx.sql("SELECT * FROM pdf")`` will register a pandas or | ||
PyArrow object named ``pdf`` without calling | ||
:py:meth:`~datafusion.SessionContext.from_pandas` or | ||
:py:meth:`~datafusion.SessionContext.from_arrow` explicitly. This requires | ||
the corresponding library (``pandas`` for pandas objects, ``pyarrow`` for | ||
Arrow objects) to be installed. | ||
|
||
|
||
.. code-block:: python | ||
|
||
import pandas as pd | ||
from datafusion import SessionContext | ||
|
||
ctx = SessionContext(auto_register_python_objects=True) | ||
Review comment: This is a long parameter; what do we think about turning it on by default and/or choosing a shorter name?
Reply: Flipping this on by default would change long-standing failure modes: queries that currently raise “table not found” would start consulting the caller’s scope, which could mask mistakes or introduce non-deterministic behavior when multiple similarly named objects exist. |
||
pdf = pd.DataFrame({"value": [1, 2, 3]}) | ||
|
||
df = ctx.sql("SELECT SUM(value) AS total FROM pdf") | ||
print(df.to_pandas()) # automatically registers `pdf` | ||
|
||
Automatic lookup is disabled by default. Enable it by passing | ||
``auto_register_python_objects=True`` when constructing the session or by | ||
configuring :py:class:`~datafusion.SessionConfig` with | ||
:py:meth:`~datafusion.SessionConfig.with_python_table_lookup`. Use | ||
:py:meth:`~datafusion.SessionContext.set_python_table_lookup` to toggle the | ||
behaviour at runtime. | ||
|
||
See: :py:class:`datafusion.SessionContext` | ||
|
||
Expression Classes | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,10 @@ | |
|
||
from __future__ import annotations | ||
|
||
import inspect | ||
import re | ||
import warnings | ||
import weakref | ||
from typing import TYPE_CHECKING, Any, Protocol | ||
|
||
try: | ||
|
@@ -101,6 +104,7 @@ def __init__(self, config_options: dict[str, str] | None = None) -> None: | |
config_options: Configuration options. | ||
""" | ||
self.config_internal = SessionConfigInternal(config_options) | ||
self._python_table_lookup = False | ||
|
||
def with_create_default_catalog_and_schema( | ||
self, enabled: bool = True | ||
|
@@ -270,6 +274,11 @@ def with_parquet_pruning(self, enabled: bool = True) -> SessionConfig: | |
self.config_internal = self.config_internal.with_parquet_pruning(enabled) | ||
return self | ||
|
||
def with_python_table_lookup(self, enabled: bool = True) -> SessionConfig: | ||
"""Enable implicit table lookup for Python objects when running SQL.""" | ||
self._python_table_lookup = enabled | ||
return self | ||
|
||
def set(self, key: str, value: str) -> SessionConfig: | ||
"""Set a configuration option. | ||
|
||
|
@@ -483,6 +492,8 @@ def __init__( | |
self, | ||
config: SessionConfig | None = None, | ||
runtime: RuntimeEnvBuilder | None = None, | ||
*, | ||
auto_register_python_objects: bool | None = None, | ||
) -> None: | ||
"""Main interface for executing queries with DataFusion. | ||
|
||
|
@@ -493,6 +504,12 @@ def __init__( | |
Args: | ||
config: Session configuration options. | ||
runtime: Runtime configuration options. | ||
auto_register_python_objects: Automatically register referenced | ||
Python objects (such as pandas or PyArrow data) when ``sql`` | ||
queries reference them by name. When omitted, this defaults to | ||
the value configured via | ||
:py:meth:`~datafusion.SessionConfig.with_python_table_lookup` | ||
(``False`` unless explicitly enabled). | ||
|
||
Example usage: | ||
|
||
|
@@ -504,10 +521,22 @@ def __init__( | |
ctx = SessionContext() | ||
df = ctx.read_csv("data.csv") | ||
""" | ||
config = config.config_internal if config is not None else None | ||
runtime = runtime.config_internal if runtime is not None else None | ||
self.ctx = SessionContextInternal( | ||
config.config_internal if config is not None else None, | ||
runtime.config_internal if runtime is not None else None, | ||
) | ||
|
||
# Determine the final value for python table lookup | ||
if auto_register_python_objects is not None: | ||
auto_python_table_lookup = auto_register_python_objects | ||
else: | ||
# Default to session config value or False if not configured | ||
auto_python_table_lookup = getattr(config, "_python_table_lookup", False) | ||
|
||
self.ctx = SessionContextInternal(config, runtime) | ||
self._auto_python_table_lookup = bool(auto_python_table_lookup) | ||
self._python_table_bindings: dict[ | ||
str, tuple[weakref.ReferenceType[Any] | None, int] | ||
] = {} | ||
|
||
def __repr__(self) -> str: | ||
"""Print a string representation of the Session Context.""" | ||
|
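The precedence described in the `__init__` docstring (explicit keyword argument first, then the value stored by `with_python_table_lookup`, then `False`) can be sketched in isolation. `FakeConfig` and `resolve_lookup_flag` are hypothetical names used only for illustration and are not part of DataFusion. The sketch reads the flag from the Python-level config wrapper, since that is the object on which `with_python_table_lookup` stores `_python_table_lookup` in this diff.

```python
from typing import Optional


class FakeConfig:
    """Hypothetical stand-in for the Python-level SessionConfig wrapper."""

    def __init__(self) -> None:
        self._python_table_lookup = False

    def with_python_table_lookup(self, enabled: bool = True) -> "FakeConfig":
        self._python_table_lookup = enabled
        return self


def resolve_lookup_flag(
    explicit: Optional[bool], config: Optional[FakeConfig]
) -> bool:
    """Explicit keyword wins; otherwise use the config value, then False."""
    if explicit is not None:
        return bool(explicit)
    # getattr tolerates config=None and objects that lack the attribute.
    return bool(getattr(config, "_python_table_lookup", False))
```

Using `None` as the keyword default is what lets an explicit `False` override a config that enabled the lookup.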
@@ -534,8 +563,27 @@ def enable_url_table(self) -> SessionContext: | |
klass = self.__class__ | ||
obj = klass.__new__(klass) | ||
obj.ctx = self.ctx.enable_url_table() | ||
obj._auto_python_table_lookup = getattr( | ||
self, "_auto_python_table_lookup", False | ||
) | ||
obj._python_table_bindings = getattr(self, "_python_table_bindings", {}).copy() | ||
return obj | ||
|
||
def set_python_table_lookup(self, enabled: bool = True) -> SessionContext: | ||
"""Enable or disable automatic registration of Python objects in SQL. | ||
|
||
Args: | ||
enabled: When ``True``, SQL queries automatically attempt to | ||
resolve missing table names by looking up Python objects in the | ||
caller's scope. Use ``False`` to require explicit registration | ||
of any referenced tables. | ||
|
||
Returns: | ||
The current :py:class:`SessionContext` instance for chaining. | ||
""" | ||
self._auto_python_table_lookup = enabled | ||
return self | ||
|
||
def register_object_store( | ||
self, schema: str, store: Any, host: str | None = None | ||
) -> None: | ||
|
@@ -600,9 +648,34 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: | |
Returns: | ||
DataFrame representation of the SQL query. | ||
""" | ||
if options is None: | ||
return DataFrame(self.ctx.sql(query)) | ||
return DataFrame(self.ctx.sql_with_options(query, options.options_internal)) | ||
|
||
def _execute_sql() -> DataFrame: | ||
if options is None: | ||
return DataFrame(self.ctx.sql(query)) | ||
return DataFrame(self.ctx.sql_with_options(query, options.options_internal)) | ||
|
||
auto_lookup_enabled = getattr(self, "_auto_python_table_lookup", False) | ||
|
||
if auto_lookup_enabled: | ||
self._refresh_python_table_bindings() | ||
|
||
while True: | ||
try: | ||
return _execute_sql() | ||
except Exception as err: # noqa: PERF203 | ||
if not auto_lookup_enabled: | ||
raise | ||
|
||
missing_tables = self._extract_missing_table_names(err) | ||
if not missing_tables: | ||
raise | ||
|
||
registered = self._register_python_tables(missing_tables) | ||
if not registered: | ||
raise | ||
|
||
# Retry to allow registering additional tables referenced in the query. | ||
continue | ||
|
||
def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: | ||
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text. | ||
|
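The register-on-miss loop in `sql` can be illustrated without DataFusion. The sketch below is a simplified model, not the PR's implementation: `FakeEngine`, `extract_missing`, and `run_with_lookup` are hypothetical, the "query" is just a space-separated list of table names, and the scope is passed in as a plain dict instead of being discovered from caller frames. The two regular expressions are copied from the diff; whether they match every engine error message is an assumption.

```python
import re
from typing import Any, Dict, List

# Patterns copied from the diff; real engine messages may differ.
_PATTERNS = (r"table '([^']+)' not found", r"No table named '([^']+)'")


def extract_missing(message: str) -> List[str]:
    """Return unqualified table names mentioned in a 'not found' message."""
    names = set()
    for pattern in _PATTERNS:
        names.update(re.findall(pattern, message))
    # Keep only the last component of a qualified name such as cat.schema.t
    return sorted(name.rsplit(".", 1)[-1] for name in names)


class FakeEngine:
    """Hypothetical engine: a query is a space-separated list of table names."""

    def __init__(self) -> None:
        self.tables: Dict[str, Any] = {}

    def run(self, query: str) -> List[Any]:
        for name in query.split():
            if name not in self.tables:
                raise RuntimeError(f"table '{name}' not found")
        return [self.tables[name] for name in query.split()]


def run_with_lookup(
    engine: FakeEngine, query: str, scope: Dict[str, Any]
) -> List[Any]:
    """Run the query, registering missing tables from scope and retrying."""
    while True:
        try:
            return engine.run(query)
        except RuntimeError as err:
            registered = False
            for name in extract_missing(str(err)):
                if name in scope and name not in engine.tables:
                    engine.tables[name] = scope[name]
                    registered = True
            if not registered:
                raise  # nothing new was registered, so a retry cannot succeed
```

The loop terminates because every retry is preceded by at least one new registration; when nothing can be registered, the original error propagates unchanged.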
@@ -619,6 +692,138 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: | |
""" | ||
return self.sql(query, options) | ||
|
||
@staticmethod | ||
def _extract_missing_table_names(err: Exception) -> list[str]: | ||
def _normalize(names: list[Any]) -> list[str]: | ||
tables: list[str] = [] | ||
for raw_name in names: | ||
if not raw_name: | ||
continue | ||
raw_str = str(raw_name) | ||
tables.append(raw_str.rsplit(".", 1)[-1]) | ||
return tables | ||
|
||
missing_tables = getattr(err, "missing_table_names", None) | ||
if missing_tables is not None: | ||
if isinstance(missing_tables, str): | ||
candidates: list[Any] = [missing_tables] | ||
else: | ||
try: | ||
candidates = list(missing_tables) | ||
except TypeError: | ||
candidates = [missing_tables] | ||
|
||
return _normalize(candidates) | ||
|
||
message = str(err) | ||
matches = set() | ||
for pattern in (r"table '([^']+)' not found", r"No table named '([^']+)'"): | ||
matches.update(re.findall(pattern, message)) | ||
|
||
return _normalize(list(matches)) | ||
|
||
def _register_python_tables(self, tables: list[str]) -> bool: | ||
registered_any = False | ||
for table_name in tables: | ||
if not table_name or self.table_exist(table_name): | ||
continue | ||
|
||
python_obj = self._lookup_python_object(table_name) | ||
if python_obj is None: | ||
continue | ||
|
||
if self._register_python_object(table_name, python_obj): | ||
registered_any = True | ||
|
||
return registered_any | ||
|
||
@staticmethod | ||
def _lookup_python_object(name: str) -> Any | None: | ||
frame = inspect.currentframe() | ||
try: | ||
frame = frame.f_back if frame is not None else None | ||
lower_name = name.lower() | ||
|
||
def _match(mapping: dict[str, Any]) -> Any | None: | ||
value = mapping.get(name) | ||
if value is not None: | ||
return value | ||
|
||
for key, candidate in mapping.items(): | ||
if ( | ||
isinstance(key, str) | ||
and key.lower() == lower_name | ||
and candidate is not None | ||
): | ||
return candidate | ||
|
||
return None | ||
|
||
while frame is not None: | ||
for scope in (frame.f_locals, frame.f_globals): | ||
match = _match(scope) | ||
if match is not None: | ||
return match | ||
frame = frame.f_back | ||
finally: | ||
del frame | ||
return None | ||
|
||
def _refresh_python_table_bindings(self) -> None: | ||
bindings = getattr(self, "_python_table_bindings", {}) | ||
for table_name, (obj_ref, cached_id) in list(bindings.items()): | ||
cached_obj = obj_ref() if obj_ref is not None else None | ||
current_obj = self._lookup_python_object(table_name) | ||
weakref_dead = obj_ref is not None and cached_obj is None | ||
id_mismatch = current_obj is not None and id(current_obj) != cached_id | ||
|
||
if not (weakref_dead or id_mismatch): | ||
continue | ||
|
||
self.deregister_table(table_name) | ||
|
||
if current_obj is None: | ||
bindings.pop(table_name, None) | ||
continue | ||
|
||
if self._register_python_object(table_name, current_obj): | ||
continue | ||
|
||
bindings.pop(table_name, None) | ||
|
||
def _register_python_object(self, name: str, obj: Any) -> bool: | ||
registered = False | ||
|
||
if isinstance(obj, DataFrame): | ||
self.register_view(name, obj) | ||
registered = True | ||
elif ( | ||
obj.__class__.__module__.startswith("polars.") | ||
and obj.__class__.__name__ == "DataFrame" | ||
): | ||
self.from_polars(obj, name=name) | ||
registered = True | ||
elif ( | ||
obj.__class__.__module__.startswith("pandas.") | ||
and obj.__class__.__name__ == "DataFrame" | ||
): | ||
self.from_pandas(obj, name=name) | ||
registered = True | ||
elif isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)) or ( | ||
hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__") | ||
): | ||
self.from_arrow(obj, name=name) | ||
registered = True | ||
Comment on lines 797 to 802
Review comment: IMO all of this should (or at least could) be replaced with `hasattr(obj, "__arrow_c_stream__")` to use the PyCapsule Interface. Unless we want to support old versions of Pandas and Polars?
Reply: Good point. I will invert the `if` comparison to check for `hasattr(obj, "__arrow_c_stream__")` before falling back to checking for modules, as there are older versions of Pandas (and maybe Polars) that don't support `__arrow_c_stream__`. |
||
|
||
if registered: | ||
try: | ||
reference: weakref.ReferenceType[Any] | None = weakref.ref(obj) | ||
except TypeError: | ||
reference = None | ||
self._python_table_bindings[name] = (reference, id(obj)) | ||
|
||
return registered | ||
|
||
def create_dataframe( | ||
self, | ||
partitions: list[list[pa.RecordBatch]], | ||
|
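The ordering proposed in the review thread on lines 797 to 802 (try the Arrow PyCapsule interface first, then fall back to module-name checks) might look roughly like the sketch below. It is an illustration under stated assumptions, not the PR's final code: `classify` is a hypothetical helper that returns a label instead of registering anything, and the check for DataFusion's own `DataFrame` is left out.

```python
from typing import Any, Optional


def classify(obj: Any) -> Optional[str]:
    """Return a label for the registration path, or None if unsupported."""
    # Prefer the Arrow PyCapsule interface when the object exposes it.
    if hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__"):
        return "arrow"
    # Fall back to module-name checks for library versions that predate
    # PyCapsule support.
    cls = type(obj)
    module = cls.__module__
    if cls.__name__ == "DataFrame":
        if module == "polars" or module.startswith("polars."):
            return "polars"
        if module == "pandas" or module.startswith("pandas."):
            return "pandas"
    return None


class StreamLike:
    """Hypothetical object that advertises the PyCapsule stream protocol."""

    def __arrow_c_stream__(self, requested_schema: Any = None) -> Any:
        raise NotImplementedError  # never called in this sketch


# Hypothetical stand-in for a DataFrame class from an older pandas release.
LegacyPandasFrame = type("DataFrame", (), {"__module__": "pandas.core.frame"})
```

Checking class and module names rather than importing pandas or polars keeps both libraries optional, which matches the approach taken in the diff.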
@@ -756,6 +961,7 @@ def register_table(self, name: str, table: Table) -> None: | |
def deregister_table(self, name: str) -> None: | ||
"""Remove a table from the session.""" | ||
self.ctx.deregister_table(name) | ||
self._python_table_bindings.pop(name, None) | ||
|
||
def catalog_names(self) -> set[str]: | ||
"""Returns the list of catalogs in this context.""" | ||
|
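The scope lookup performed by `_lookup_python_object` relies on walking the interpreter's frame stack. A reduced sketch of that technique follows; `lookup_in_callers` and `caller_with_local` are hypothetical names, and the case-insensitive fallback from the diff is omitted for brevity.

```python
import inspect
from typing import Any, Optional


def lookup_in_callers(name: str) -> Optional[Any]:
    """Search each caller frame's locals, then globals, for a non-None value."""
    frame = inspect.currentframe()
    try:
        # Skip this function's own frame.
        frame = frame.f_back if frame is not None else None
        while frame is not None:
            for scope in (frame.f_locals, frame.f_globals):
                value = scope.get(name)
                if value is not None:
                    return value
            frame = frame.f_back
        return None
    finally:
        # Drop the frame reference to avoid keeping a reference cycle alive.
        del frame


def caller_with_local() -> Optional[Any]:
    pdf = [1, 2, 3]  # local variable that the lookup should find
    return lookup_in_callers("pdf")
```

Because the walk starts at the immediate caller and continues outward, every intermediate frame is searched before user code, so a table name that happens to match a local variable in a helper frame would resolve to that variable first.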
Review comment: Is the registration temporary? Or after the query ends is `pdf` now still bound to the specific object?
Reply: Registrations persist: once a variable is bound we cache a weak reference plus its id in `_python_table_bindings`. On every subsequent SQL call we refresh that cache, dropping the registration if the object has been garbage collected, reassigned, or otherwise moved, but as long as the original object is still alive the table name remains usable across queries.
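The staleness check behind that cache can be sketched on its own. The names below (`Binding`, `make_binding`, `is_stale`, `Frame`) are hypothetical; the two conditions mirror `weakref_dead` and `id_mismatch` from `_refresh_python_table_bindings` in the diff.

```python
import weakref
from typing import Any, Optional, Tuple

Binding = Tuple[Optional[weakref.ref], int]


def make_binding(obj: Any) -> Binding:
    """Record a weak reference (when the type supports one) plus the id."""
    try:
        ref: Optional[weakref.ref] = weakref.ref(obj)
    except TypeError:
        ref = None  # e.g. plain lists and dicts cannot be weakly referenced
    return ref, id(obj)


def is_stale(binding: Binding, current: Optional[Any]) -> bool:
    """Stale if the cached object died or the name now points elsewhere."""
    ref, cached_id = binding
    dead = ref is not None and ref() is None
    rebound = current is not None and id(current) != cached_id
    return dead or rebound


class Frame:
    """Hypothetical weak-referenceable stand-in for a data frame."""


first = Frame()
binding = make_binding(first)
second = Frame()  # simulates the variable being reassigned to a new object
```

Here `is_stale(binding, first)` is false and `is_stale(binding, second)` is true. Under these two conditions, an object that cannot be weakly referenced and whose name later disappears from scope is never flagged as stale, so its registration would remain until it is removed explicitly.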