Skip to content

Commit d91b738

Browse files
add a test to capture the bug
1 parent 766e2ed commit d91b738

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

python/datafusion/tests/test_context.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,54 @@ def test_dataset_filter(ctx, capfd):
343343
assert result[0].column(1) == pa.array([-3])
344344

345345

346+
def test_dataset_count(ctx):
    """Regression test: counting rows of a registered pyarrow dataset.

    `datafusion-python` issue: https://github.com/apache/datafusion-python/issues/800
    Probably related to:
    - Support RecordBatch with zero columns but non zero row count
      (https://github.com/apache/arrow-rs/issues/1536)
      * PR: https://github.com/apache/arrow-rs/pull/1552
    - https://github.com/apache/arrow-rs/issues/1783
    """
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    dataset = ds.dataset([batch])
    ctx.register_dataset("t", dataset)

    # The bug occurs in both the dataframe and SQL APIs.
    df = ctx.table("t")
    assert df.count() == 3

    count = ctx.sql("SELECT COUNT(*) FROM t")

    # Plan produced by count.explain(verbose=False), kept for reference —
    # the COUNT(*) scan uses an empty projection (projection=[]), which is
    # exactly the zero-column / non-zero-row-count case referenced above:
    # | logical_plan  | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]        |
    # |               |   TableScan: t projection=[]                                         |
    # | physical_plan | AggregateExec: mode=Final, gby=[], aggr=[count(*)]                   |
    # |               |   CoalescePartitionsExec                                             |
    # |               |     AggregateExec: mode=Partial, gby=[], aggr=[count(*)]             |
    # |               |       RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 |
    # |               |         DatasetExec: number_of_fragments=1, projection=[]            |

    count = count.collect()
    assert count[0].column(0) == pa.array([3])
392+
393+
346394
def test_pyarrow_predicate_pushdown_is_null(ctx, capfd):
347395
"""Ensure that pyarrow filter gets pushed down for `IsNull`"""
348396
# create a RecordBatch and register it as a pyarrow.dataset.Dataset

0 commit comments

Comments
 (0)