@@ -343,6 +343,54 @@ def test_dataset_filter(ctx, capfd):
343343    assert  result [0 ].column (1 ) ==  pa .array ([- 3 ])
344344
345345
def test_dataset_count(ctx):
    """Regression test: COUNT over a registered pyarrow dataset.

    `datafusion-python` issue: https://github.com/apache/datafusion-python/issues/800
    Probably related to:
      - Support RecordBatch with zero columns but non-zero row count
        (https://github.com/apache/arrow-rs/issues/1536,
        PR: https://github.com/apache/arrow-rs/pull/1552)
      - https://github.com/apache/arrow-rs/issues/1783
    """
    # Register a small two-column dataset so COUNT(*) has rows to aggregate.
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    dataset = ds.dataset([batch])
    ctx.register_dataset("t", dataset)

    # The bug occurs in both the dataframe and SQL APIs.
    df = ctx.table("t")
    assert df.count() == 3

    count = ctx.sql("SELECT COUNT(*) FROM t")

    # print(count.explain(verbose=False))
    # +---------------+----------------------------------------------------------------------------+
    # | plan_type     | plan                                                                       |
    # +---------------+----------------------------------------------------------------------------+
    # | logical_plan  | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]              |
    # |               |   TableScan: t projection=[]                                               |
    # | physical_plan | AggregateExec: mode=Final, gby=[], aggr=[count(*)]                         |
    # |               |   CoalescePartitionsExec                                                   |
    # |               |     AggregateExec: mode=Partial, gby=[], aggr=[count(*)]                   |
    # |               |       RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 |
    # |               |         DatasetExec: number_of_fragments=1, projection=[]                  |
    # |               |                                                                            |
    # +---------------+----------------------------------------------------------------------------+
    # NOTE: the empty projection (`TableScan: t projection=[]`) is what previously
    # produced a zero-column RecordBatch and triggered the linked arrow-rs issues.

    # Collect into a fresh name instead of rebinding `count` from a DataFrame
    # to a list of record batches.
    result = count.collect()
    assert result[0].column(0) == pa.array([3])
393+ 
346394def  test_pyarrow_predicate_pushdown_is_null (ctx , capfd ):
347395    """Ensure that pyarrow filter gets pushed down for `IsNull`""" 
348396    # create a RecordBatch and register it as a pyarrow.dataset.Dataset 
0 commit comments