Skip to content

Commit 5a02fed

Browse files
authored
chore: change return type of Count() from u64 to i64 (#92)
1 parent 66fee4a commit 5a02fed

File tree

10 files changed

+104
-104
lines changed

10 files changed

+104
-104
lines changed

datafusion/core/src/optimizer/single_distinct_to_groupby.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ mod tests {
238238
.build()?;
239239

240240
// Should work
241-
let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):UInt64;N]\
242-
\n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\
241+
let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):Int64;N]\
242+
\n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):Int64;N]\
243243
\n Aggregate: groupBy=[[#test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
244244
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
245245

@@ -255,8 +255,8 @@ mod tests {
255255
.aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
256256
.build()?;
257257

258-
let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):UInt64;N]\
259-
\n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\
258+
let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\
259+
\n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):Int64;N]\
260260
\n Aggregate: groupBy=[[Int32(2) * #test.b AS alias1]], aggr=[[]] [alias1:Int32]\
261261
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
262262

@@ -273,8 +273,8 @@ mod tests {
273273
.build()?;
274274

275275
// Should work
276-
let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N]\
277-
\n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N]\
276+
let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\
277+
\n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\
278278
\n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
279279
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
280280

@@ -294,7 +294,7 @@ mod tests {
294294
.build()?;
295295

296296
// Do nothing
297-
let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(DISTINCT #test.c)]] [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, COUNT(DISTINCT test.c):UInt64;N]\
297+
let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(DISTINCT #test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\
298298
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
299299

300300
assert_optimized_plan_eq(&plan, expected);
@@ -319,8 +319,8 @@ mod tests {
319319
)?
320320
.build()?;
321321
// Should work
322-
let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, MAX(DISTINCT test.b):UInt32;N]\
323-
\n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N, MAX(alias1):UInt32;N]\
322+
let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
323+
\n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\
324324
\n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
325325
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
326326

@@ -340,7 +340,7 @@ mod tests {
340340
.build()?;
341341

342342
// Do nothing
343-
let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(#test.c)]] [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, COUNT(test.c):UInt64;N]\
343+
let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(#test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\
344344
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
345345

346346
assert_optimized_plan_eq(&plan, expected);

datafusion/core/src/physical_optimizer/aggregate_statistics.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ fn take_optimizable_table_count(
150150
{
151151
if lit_expr.value() == &ScalarValue::UInt8(Some(1)) {
152152
return Some((
153-
ScalarValue::UInt64(Some(num_rows as u64)),
153+
ScalarValue::Int64(Some(num_rows as i64)),
154154
"COUNT(UInt8(1))",
155155
));
156156
}
@@ -183,7 +183,7 @@ fn take_optimizable_column_count(
183183
{
184184
let expr = format!("COUNT({})", col_expr.name());
185185
return Some((
186-
ScalarValue::UInt64(Some((num_rows - val) as u64)),
186+
ScalarValue::Int64(Some((num_rows - val) as i64)),
187187
expr,
188188
));
189189
}
@@ -254,7 +254,7 @@ mod tests {
254254
use super::*;
255255
use std::sync::Arc;
256256

257-
use arrow::array::{Int32Array, UInt64Array};
257+
use arrow::array::{Int32Array, Int64Array};
258258
use arrow::datatypes::{DataType, Field, Schema};
259259
use arrow::record_batch::RecordBatch;
260260

@@ -301,8 +301,8 @@ mod tests {
301301
let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?;
302302

303303
let (col, count) = match nulls {
304-
false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false), 3),
305-
true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
304+
false => (Field::new("COUNT(UInt8(1))", DataType::Int64, false), 3),
305+
true => (Field::new("COUNT(a)", DataType::Int64, false), 2),
306306
};
307307

308308
// A ProjectionExec is a sign that the count optimization was applied
@@ -313,7 +313,7 @@ mod tests {
313313
result[0]
314314
.column(0)
315315
.as_any()
316-
.downcast_ref::<UInt64Array>()
316+
.downcast_ref::<Int64Array>()
317317
.unwrap()
318318
.values(),
319319
&[count]
@@ -327,7 +327,7 @@ mod tests {
327327
None => expressions::lit(ScalarValue::UInt8(Some(1))),
328328
Some(s) => expressions::col(col.unwrap(), s).unwrap(),
329329
};
330-
Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64))
330+
Arc::new(Count::new(expr, "my_count_alias", DataType::Int64))
331331
}
332332

333333
#[tokio::test]

datafusion/core/src/physical_plan/aggregates.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ pub fn return_type(
5555
match fun {
5656
// The returned data type is Int64 for compatibility with PostgreSQL, whose count() returns bigint.
5757
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
58-
Ok(DataType::UInt64)
58+
Ok(DataType::Int64)
5959
}
6060
AggregateFunction::Max | AggregateFunction::Min => {
6161
// For min and max agg function, the returned type is same as input type.
@@ -432,7 +432,7 @@ mod tests {
432432
assert!(result_agg_phy_exprs.as_any().is::<Count>());
433433
assert_eq!("c1", result_agg_phy_exprs.name());
434434
assert_eq!(
435-
Field::new("c1", DataType::UInt64, true),
435+
Field::new("c1", DataType::Int64, true),
436436
result_agg_phy_exprs.field().unwrap()
437437
);
438438
}
@@ -475,7 +475,7 @@ mod tests {
475475
assert!(result_distinct.as_any().is::<DistinctCount>());
476476
assert_eq!("c1", result_distinct.name());
477477
assert_eq!(
478-
Field::new("c1", DataType::UInt64, true),
478+
Field::new("c1", DataType::Int64, true),
479479
result_distinct.field().unwrap()
480480
);
481481
}
@@ -1121,14 +1121,14 @@ mod tests {
11211121
#[test]
11221122
fn test_count_return_type() -> Result<()> {
11231123
let observed = return_type(&AggregateFunction::Count, &[DataType::Utf8])?;
1124-
assert_eq!(DataType::UInt64, observed);
1124+
assert_eq!(DataType::Int64, observed);
11251125

11261126
let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?;
1127-
assert_eq!(DataType::UInt64, observed);
1127+
assert_eq!(DataType::Int64, observed);
11281128

11291129
let observed =
11301130
return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)])?;
1131-
assert_eq!(DataType::UInt64, observed);
1131+
assert_eq!(DataType::Int64, observed);
11321132
Ok(())
11331133
}
11341134

datafusion/core/src/physical_plan/window_functions.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ mod tests {
112112
fn test_count_return_type() -> Result<()> {
113113
let fun = WindowFunction::from_str("count")?;
114114
let observed = return_type(&fun, &[DataType::Utf8])?;
115-
assert_eq!(DataType::UInt64, observed);
115+
assert_eq!(DataType::Int64, observed);
116116

117117
let observed = return_type(&fun, &[DataType::UInt64])?;
118-
assert_eq!(DataType::UInt64, observed);
118+
assert_eq!(DataType::Int64, observed);
119119

120120
Ok(())
121121
}

datafusion/core/src/physical_plan/windows/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ mod tests {
235235

236236
// c3 is small int
237237

238-
let count: &UInt64Array = as_primitive_array(&columns[0]);
238+
let count: &Int64Array = as_primitive_array(&columns[0]);
239239
assert_eq!(count.value(0), 100);
240240
assert_eq!(count.value(99), 100);
241241

datafusion/core/tests/custom_sources.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{Int32Array, PrimitiveArray, UInt64Array};
18+
use arrow::array::{Int32Array, Int64Array, PrimitiveArray};
1919
use arrow::compute::kernels::aggregate;
2020
use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef};
2121
use arrow::error::Result as ArrowResult;
@@ -271,12 +271,12 @@ async fn optimizers_catch_all_statistics() {
271271

272272
let expected = RecordBatch::try_new(
273273
Arc::new(Schema::new(vec![
274-
Field::new("COUNT(UInt8(1))", DataType::UInt64, false),
274+
Field::new("COUNT(UInt8(1))", DataType::Int64, false),
275275
Field::new("MIN(test.c1)", DataType::Int32, false),
276276
Field::new("MAX(test.c1)", DataType::Int32, false),
277277
])),
278278
vec![
279-
Arc::new(UInt64Array::from_slice(&[4])),
279+
Arc::new(Int64Array::from_slice(&[4])),
280280
Arc::new(Int32Array::from_slice(&[1])),
281281
Arc::new(Int32Array::from_slice(&[100])),
282282
],

datafusion/core/tests/provider_filter_pushdown.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{as_primitive_array, Int32Builder, UInt64Array};
18+
use arrow::array::{as_primitive_array, Int32Builder, Int64Array};
1919
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
2020
use arrow::record_batch::RecordBatch;
2121
use async_trait::async_trait;
@@ -167,7 +167,7 @@ impl TableProvider for CustomProvider {
167167
}
168168
}
169169

170-
async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<()> {
170+
async fn assert_provider_row_count(value: i64, expected_count: i64) -> Result<()> {
171171
let provider = CustomProvider {
172172
zero_batch: create_batch(0, 10)?,
173173
one_batch: create_batch(1, 5)?,
@@ -180,7 +180,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<()
180180
.aggregate(vec![], vec![count(col("flag"))])?;
181181

182182
let results = df.collect().await?;
183-
let result_col: &UInt64Array = as_primitive_array(results[0].column(0));
183+
let result_col: &Int64Array = as_primitive_array(results[0].column(0));
184184
assert_eq!(result_col.value(0), expected_count);
185185

186186
ctx.register_table("data", Arc::new(provider))?;
@@ -190,7 +190,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<()
190190
.collect()
191191
.await?;
192192

193-
let sql_result_col: &UInt64Array = as_primitive_array(sql_results[0].column(0));
193+
let sql_result_col: &Int64Array = as_primitive_array(sql_results[0].column(0));
194194
assert_eq!(sql_result_col.value(0), expected_count);
195195

196196
Ok(())

datafusion/core/tests/sql/explain_analyze.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ async fn explain_analyze_baseline_metrics() {
8181
);
8282
assert_metrics!(
8383
&formatted,
84-
"ProjectionExec: expr=[CAST(COUNT(UInt8(1))",
84+
"ProjectionExec: expr=[COUNT(UInt8(1)",
8585
"metrics=[output_rows=1, elapsed_compute="
8686
);
8787
assert_metrics!(

datafusion/physical-expr/src/expressions/count.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use crate::{AggregateExpr, PhysicalExpr};
2424
use arrow::compute;
2525
use arrow::datatypes::DataType;
2626
use arrow::{
27-
array::{ArrayRef, UInt64Array},
27+
array::{ArrayRef, Int64Array},
2828
datatypes::Field,
2929
};
3030
use datafusion_common::Result;
@@ -96,7 +96,7 @@ impl AggregateExpr for Count {
9696

9797
#[derive(Debug)]
9898
struct CountAccumulator {
99-
count: u64,
99+
count: i64,
100100
}
101101

102102
impl CountAccumulator {
@@ -109,25 +109,25 @@ impl CountAccumulator {
109109
impl Accumulator for CountAccumulator {
110110
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
111111
let array = &values[0];
112-
self.count += (array.len() - array.data().null_count()) as u64;
112+
self.count += (array.len() - array.data().null_count()) as i64;
113113
Ok(())
114114
}
115115

116116
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
117-
let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap();
117+
let counts = states[0].as_any().downcast_ref::<Int64Array>().unwrap();
118118
let delta = &compute::sum(counts);
119119
if let Some(d) = delta {
120-
self.count += *d;
120+
self.count += *d as i64;
121121
}
122122
Ok(())
123123
}
124124

125125
fn state(&self) -> Result<Vec<ScalarValue>> {
126-
Ok(vec![ScalarValue::UInt64(Some(self.count))])
126+
Ok(vec![ScalarValue::Int64(Some(self.count))])
127127
}
128128

129129
fn evaluate(&self) -> Result<ScalarValue> {
130-
Ok(ScalarValue::UInt64(Some(self.count)))
130+
Ok(ScalarValue::Int64(Some(self.count)))
131131
}
132132
}
133133

@@ -148,8 +148,8 @@ mod tests {
148148
a,
149149
DataType::Int32,
150150
Count,
151-
ScalarValue::from(5u64),
152-
DataType::UInt64
151+
ScalarValue::from(5i64),
152+
DataType::Int64
153153
)
154154
}
155155

@@ -167,8 +167,8 @@ mod tests {
167167
a,
168168
DataType::Int32,
169169
Count,
170-
ScalarValue::from(3u64),
171-
DataType::UInt64
170+
ScalarValue::from(3i64),
171+
DataType::Int64
172172
)
173173
}
174174

@@ -181,8 +181,8 @@ mod tests {
181181
a,
182182
DataType::Boolean,
183183
Count,
184-
ScalarValue::from(0u64),
185-
DataType::UInt64
184+
ScalarValue::from(0i64),
185+
DataType::Int64
186186
)
187187
}
188188

@@ -194,8 +194,8 @@ mod tests {
194194
a,
195195
DataType::Boolean,
196196
Count,
197-
ScalarValue::from(0u64),
198-
DataType::UInt64
197+
ScalarValue::from(0i64),
198+
DataType::Int64
199199
)
200200
}
201201

@@ -207,8 +207,8 @@ mod tests {
207207
a,
208208
DataType::Utf8,
209209
Count,
210-
ScalarValue::from(5u64),
211-
DataType::UInt64
210+
ScalarValue::from(5i64),
211+
DataType::Int64
212212
)
213213
}
214214

@@ -220,8 +220,8 @@ mod tests {
220220
a,
221221
DataType::LargeUtf8,
222222
Count,
223-
ScalarValue::from(5u64),
224-
DataType::UInt64
223+
ScalarValue::from(5i64),
224+
DataType::Int64
225225
)
226226
}
227227
}

0 commit comments

Comments
 (0)