Skip to content

Commit e67dac0

Browse files
committed
fix: improve table registration logic to handle raw tables and enhance error handling
1 parent 92b22ee commit e67dac0

File tree

2 files changed

+34
-16
lines changed

2 files changed

+34
-16
lines changed

python/tests/test_catalog.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,28 @@ def test_python_table_provider(ctx: SessionContext):
164164
assert schema.table_names() == {"table4"}
165165

166166

167+
def test_register_raw_table_without_capsule(ctx: SessionContext, database, monkeypatch):
168+
schema = ctx.catalog().schema("public")
169+
raw_table = schema.table("csv").table
170+
171+
def fail(*args, **kwargs):
172+
raise AssertionError("RawTable capsule path should not be invoked")
173+
174+
monkeypatch.setattr(type(raw_table), "__datafusion_table_provider__", fail)
175+
176+
schema.register_table("csv_copy", raw_table)
177+
178+
# Restore the original implementation to avoid interfering with later assertions
179+
monkeypatch.undo()
180+
181+
batches = ctx.sql("select count(*) from csv_copy").collect()
182+
183+
assert len(batches) == 1
184+
assert batches[0].column(0) == pa.array([4])
185+
186+
schema.deregister_table("csv_copy")
187+
188+
167189
def test_in_end_to_end_python_providers(ctx: SessionContext):
168190
"""Test registering all python providers and running a query against them."""
169191

src/catalog.rs

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,9 @@ impl PySchema {
197197
}
198198

199199
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
200-
let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
200+
let provider = if let Ok(py_table) = table_provider.extract::<PyTable>() {
201+
py_table.table
202+
} else if table_provider.hasattr("__datafusion_table_provider__")? {
201203
let capsule = table_provider
202204
.getattr("__datafusion_table_provider__")?
203205
.call0()?;
@@ -208,14 +210,9 @@ impl PySchema {
208210
let provider: ForeignTableProvider = provider.into();
209211
Arc::new(provider) as Arc<dyn TableProvider>
210212
} else {
211-
match table_provider.extract::<PyTable>() {
212-
Ok(py_table) => py_table.table,
213-
Err(_) => {
214-
let py = table_provider.py();
215-
let provider = Dataset::new(&table_provider, py)?;
216-
Arc::new(provider) as Arc<dyn TableProvider>
217-
}
218-
}
213+
let py = table_provider.py();
214+
let provider = Dataset::new(&table_provider, py)?;
215+
Arc::new(provider) as Arc<dyn TableProvider>
219216
};
220217

221218
let _ = self
@@ -322,6 +319,10 @@ impl RustWrappedPySchemaProvider {
322319
return Ok(None);
323320
}
324321

322+
if let Ok(inner_table) = py_table.extract::<PyTable>() {
323+
return Ok(Some(inner_table.table));
324+
}
325+
325326
if py_table.hasattr("__datafusion_table_provider__")? {
326327
let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?;
327328
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
@@ -338,13 +339,8 @@ impl RustWrappedPySchemaProvider {
338339
}
339340
}
340341

341-
match py_table.extract::<PyTable>() {
342-
Ok(py_table) => Ok(Some(py_table.table)),
343-
Err(_) => {
344-
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
345-
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
346-
}
347-
}
342+
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
343+
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
348344
}
349345
})
350346
}

0 commit comments

Comments
 (0)