Skip to content

Commit 3d9cbd5

Browse files
authored
fix: pass scale to DF round in spark_round (#1341)
## Which issue does this PR close? Closes #1340.
1 parent 6cf140f commit 3d9cbd5

File tree

1 file changed

+82
-4
lines changed
  • native/spark-expr/src/math_funcs

1 file changed

+82
-4
lines changed

native/spark-expr/src/math_funcs/round.rs

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,10 @@ pub fn spark_round(
8585
let (precision, scale) = get_precision_scale(data_type);
8686
make_decimal_array(array, precision, scale, &f)
8787
}
88-
DataType::Float32 | DataType::Float64 => {
89-
Ok(ColumnarValue::Array(round(&[Arc::clone(array)])?))
90-
}
88+
DataType::Float32 | DataType::Float64 => Ok(ColumnarValue::Array(round(&[
89+
Arc::clone(array),
90+
args[1].to_array(array.len())?,
91+
])?)),
9192
dt => exec_err!("Not supported datatype for ROUND: {dt}"),
9293
},
9394
ColumnarValue::Scalar(a) => match a {
@@ -109,7 +110,7 @@ pub fn spark_round(
109110
make_decimal_scalar(a, precision, scale, &f)
110111
}
111112
ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar(
112-
ScalarValue::try_from_array(&round(&[a.to_array()?])?, 0)?,
113+
ScalarValue::try_from_array(&round(&[a.to_array()?, args[1].to_array(1)?])?, 0)?,
113114
)),
114115
dt => exec_err!("Not supported datatype for ROUND: {dt}"),
115116
},
@@ -135,3 +136,80 @@ fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
135136
Box::new(move |x: i128| (x + x.signum() * half) / div)
136137
}
137138
}
139+
140+
#[cfg(test)]
141+
mod test {
142+
use std::sync::Arc;
143+
144+
use crate::spark_round;
145+
146+
use arrow::array::{Float32Array, Float64Array};
147+
use arrow_schema::DataType;
148+
use datafusion_common::cast::{as_float32_array, as_float64_array};
149+
use datafusion_common::{Result, ScalarValue};
150+
use datafusion_expr::ColumnarValue;
151+
152+
#[test]
153+
fn test_round_f32_array() -> Result<()> {
154+
let args = vec![
155+
ColumnarValue::Array(Arc::new(Float32Array::from(vec![
156+
125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
157+
]))),
158+
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
159+
];
160+
let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float32)? else {
161+
unreachable!()
162+
};
163+
let floats = as_float32_array(&result)?;
164+
let expected = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
165+
assert_eq!(floats, &expected);
166+
Ok(())
167+
}
168+
169+
#[test]
170+
fn test_round_f64_array() -> Result<()> {
171+
let args = vec![
172+
ColumnarValue::Array(Arc::new(Float64Array::from(vec![
173+
125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
174+
]))),
175+
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
176+
];
177+
let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float64)? else {
178+
unreachable!()
179+
};
180+
let floats = as_float64_array(&result)?;
181+
let expected = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
182+
assert_eq!(floats, &expected);
183+
Ok(())
184+
}
185+
186+
#[test]
187+
fn test_round_f32_scalar() -> Result<()> {
188+
let args = vec![
189+
ColumnarValue::Scalar(ScalarValue::Float32(Some(125.2345))),
190+
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
191+
];
192+
let ColumnarValue::Scalar(ScalarValue::Float32(Some(result))) =
193+
spark_round(&args, &DataType::Float32)?
194+
else {
195+
unreachable!()
196+
};
197+
assert_eq!(result, 125.23);
198+
Ok(())
199+
}
200+
201+
#[test]
202+
fn test_round_f64_scalar() -> Result<()> {
203+
let args = vec![
204+
ColumnarValue::Scalar(ScalarValue::Float64(Some(125.2345))),
205+
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
206+
];
207+
let ColumnarValue::Scalar(ScalarValue::Float64(Some(result))) =
208+
spark_round(&args, &DataType::Float64)?
209+
else {
210+
unreachable!()
211+
};
212+
assert_eq!(result, 125.23);
213+
Ok(())
214+
}
215+
}

0 commit comments

Comments
 (0)