diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index 486159165..551cdab48 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -1867,6 +1867,56 @@ def test_to_arrow_table(df):
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}
 
 
+def _register_non_null_parquet(ctx, tmp_path):
+    """Register ``t``: a parquet table with one NOT NULL int column.
+
+    The rows (1, 2, 3) are written via ``copy ... to`` so that ``t`` is read
+    back from a parquet file rather than from the table ``t_`` itself.
+    """
+    path = tmp_path.joinpath("t.parquet")
+
+    ctx.sql("create table t_(a int not null)").collect()
+    ctx.sql("insert into t_ values (1), (2), (3)").collect()
+    ctx.sql(f"copy (select * from t_) to '{path}'").collect()
+
+    ctx.register_parquet("t", path)
+
+
+def test_parquet_non_null_column_to_pyarrow(ctx, tmp_path):
+    """Aggregating a NOT NULL parquet column converts to a pyarrow table."""
+    _register_non_null_parquet(ctx, tmp_path)
+
+    pyarrow_table = ctx.sql("select max(a) as m from t").to_arrow_table()
+    assert pyarrow_table.to_pydict() == {"m": [3]}
+
+
+def test_parquet_empty_batch_to_pyarrow(ctx, tmp_path):
+    """A ``limit 0`` query still yields the non-nullable column schema."""
+    _register_non_null_parquet(ctx, tmp_path)
+
+    pyarrow_table = ctx.sql("select * from t limit 0").to_arrow_table()
+    assert pyarrow_table.schema == pa.schema(
+        [
+            pa.field("a", pa.int32(), nullable=False),
+        ]
+    )
+
+
+def test_parquet_null_aggregation_to_pyarrow(ctx, tmp_path):
+    """``max`` over zero matching rows yields a null in a nullable column."""
+    _register_non_null_parquet(ctx, tmp_path)
+
+    pyarrow_table = ctx.sql(
+        "select max(a) as m from (select * from t where a < 0)"
+    ).to_arrow_table()
+    assert pyarrow_table.to_pydict() == {"m": [None]}
+    assert pyarrow_table.schema == pa.schema(
+        [
+            pa.field("m", pa.int32(), nullable=True),
+        ]
+    )
+
+
 def test_execute_stream(df):
     stream = df.execute_stream()
     assert all(batch is not None for batch in stream)
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 21eb6e0e2..8447a4182 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -1024,11 +1024,18 @@ impl PyDataFrame {
     /// Collect the batches and pass to Arrow Table
     fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
         let batches = self.collect(py)?.into_pyobject(py)?;
-        let schema = self.schema().into_pyobject(py)?;
+
+        // only use the DataFrame's schema if there are no batches, otherwise let the schema be
+        // determined from the batches (avoids some inconsistencies with nullable columns)
+        let args = if batches.len()? == 0 {
+            let schema = self.schema().into_pyobject(py)?;
+            PyTuple::new(py, &[batches, schema])?
+        } else {
+            PyTuple::new(py, &[batches])?
+        };
 
         // Instantiate pyarrow Table object and use its from_batches method
         let table_class = py.import("pyarrow")?.getattr("Table")?;
-        let args = PyTuple::new(py, &[batches, schema])?;
         let table: PyObject = table_class.call_method1("from_batches", args)?.into();
         Ok(table)
     }