Skip to content

Commit 1c2c6f6

Browse files
committed
Fix inconsistent schemas when converting to pyarrow
1 parent 276dc6a commit 1c2c6f6

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

python/tests/test_dataframe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,6 +1867,18 @@ def test_to_arrow_table(df):
18671867
assert set(pyarrow_table.column_names) == {"a", "b", "c"}
18681868

18691869

1870+
def test_parquet_non_null_column_to_pyarrow(ctx, tmp_path):
1871+
path = tmp_path.joinpath("t.parquet")
1872+
1873+
ctx.sql("create table t_(a int not null)").collect()
1874+
ctx.sql("insert into t_ values (1), (2), (3)").collect()
1875+
ctx.sql(f"copy (select * from t_) to '{path}'").collect()
1876+
1877+
ctx.register_parquet("t", path)
1878+
pyarrow_table = ctx.sql("select max(a) as m from t").to_arrow_table()
1879+
assert pyarrow_table.to_pydict() == {"m": [3]}
1880+
1881+
18701882
def test_execute_stream(df):
18711883
stream = df.execute_stream()
18721884
assert all(batch is not None for batch in stream)

src/dataframe.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,11 +1024,18 @@ impl PyDataFrame {
10241024
/// Collect the batches and convert them into a pyarrow `Table`
10251025
fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
10261026
let batches = self.collect(py)?.into_pyobject(py)?;
1027-
let schema = self.schema().into_pyobject(py)?;
1027+
1028+
// Only pass the DataFrame's schema when there are no batches; otherwise let pyarrow take the
1029+
// schema from the batches themselves (avoids mismatches in column nullability).
1030+
let args = if batches.len()? == 0 {
1031+
let schema = self.schema().into_pyobject(py)?;
1032+
PyTuple::new(py, &[batches, schema])?
1033+
} else {
1034+
PyTuple::new(py, &[batches])?
1035+
};
10281036

10291037
// Look up the pyarrow Table class and build the table via its from_batches class method
10301038
let table_class = py.import("pyarrow")?.getattr("Table")?;
1031-
let args = PyTuple::new(py, &[batches, schema])?;
10321039
let table: PyObject = table_class.call_method1("from_batches", args)?.into();
10331040
Ok(table)
10341041
}

0 commit comments

Comments
 (0)