Skip to content

Commit 81b46cb

Browse files
committed
give read_table the same treatment
1 parent 9964b7f commit 81b46cb

File tree

4 files changed

+31
-10
lines changed

4 files changed

+31
-10
lines changed

examples/datafusion-ffi-example/python/tests/_test_table_provider.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,3 +40,7 @@ def test_table_loading():
4040
]
4141

4242
assert result == expected
43+
44+
result = ctx.read_table(table).collect()
45+
result = [r.column(0) for r in result]
46+
assert result == expected

python/datafusion/context.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1177,14 +1177,11 @@ def read_avro(
11771177
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
11781178
)
11791179

1180-
def read_table(self, table: Table) -> DataFrame:
1181-
"""Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table.
1182-
1183-
For a :py:class:`~datafusion.catalog.Table` such as a
1184-
:py:class:`~datafusion.catalog.ListingTable`, create a
1185-
:py:class:`~datafusion.dataframe.DataFrame`.
1186-
"""
1187-
return DataFrame(self.ctx.read_table(table._inner))
1180+
def read_table(
1181+
self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset
1182+
) -> DataFrame:
1183+
"""Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table."""
1184+
return DataFrame(self.ctx.read_table(table))
11881185

11891186
def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
11901187
"""Execute the ``plan`` and return the results."""

python/tests/test_context.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -312,7 +312,7 @@ def test_register_table(ctx, database):
312312
assert public.names() == {"csv", "csv1", "csv2", "csv3"}
313313

314314

315-
def test_read_table(ctx, database):
315+
def test_read_table_from_catalog(ctx, database):
316316
default = ctx.catalog()
317317
public = default.schema("public")
318318
assert public.names() == {"csv", "csv1", "csv2"}
@@ -322,6 +322,25 @@ def test_read_table(ctx, database):
322322
table_df.show()
323323

324324

325+
def test_read_table_from_df(ctx):
326+
df = ctx.from_pydict({"a": [1, 2]})
327+
result = ctx.read_table(df).collect()
328+
assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]
329+
330+
331+
def test_read_table_from_dataset(ctx):
332+
batch = pa.RecordBatch.from_arrays(
333+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
334+
names=["a", "b"],
335+
)
336+
dataset = ds.dataset([batch])
337+
338+
result = ctx.read_table(dataset).collect()
339+
340+
assert result[0].column(0) == pa.array([1, 2, 3])
341+
assert result[0].column(1) == pa.array([4, 5, 6])
342+
343+
325344
def test_deregister_table(ctx, database):
326345
default = ctx.catalog()
327346
public = default.schema("public")

src/context.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1076,7 +1076,8 @@ impl PySessionContext {
10761076
Ok(PyDataFrame::new(df))
10771077
}
10781078

1079-
pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
1079+
pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
1080+
let table = PyTable::new(&table)?;
10801081
let df = self.ctx.read_table(table.table())?;
10811082
Ok(PyDataFrame::new(df))
10821083
}

0 commit comments

Comments
 (0)