Skip to content

Commit 5d8249f

Browse files
fix: Fix and Refactor Spark shuffle function (#20484)
## Which issue does this PR close? - Closes #20483. ## Rationale for this change Currently, Spark `shuffle` function returns following error message when `seed` is `null`. This needs to be fixed by exposing `NULL` instead of `'Int64'`. **Current:** ``` query error SELECT shuffle([2, 1], NULL); ---- DataFusion error: Execution error: shuffle seed must be Int64 type, got 'Int64' ``` **New:** ``` query error DataFusion error: Execution error: shuffle seed must be Int64 type but got 'NULL' SELECT shuffle([1, 2, 3], NULL); ``` In addition to this fix, this PR also introduces following refactoring to `shuffle` function: - Combining args validation checks with `single` error message, - Extending current error message with expected data types: ``` Current: shuffle does not support type '{array_type}'. New: shuffle does not support type '{array_type}'; expected types: List, LargeList, FixedSizeList or Null." ``` - Adding new UT coverages for both `shuffle.rs` and `shuffle.slt`. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? Yes, being added new UT cases. ## Are there any user-facing changes? Yes, updating Spark `shuffle` functions error messages.
1 parent e567cb9 commit 5d8249f

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

datafusion/spark/src/function/array/shuffle.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,8 @@ impl ScalarUDFImpl for SparkShuffle {
105105
&self,
106106
args: datafusion_expr::ScalarFunctionArgs,
107107
) -> Result<ColumnarValue> {
108-
if args.args.is_empty() {
109-
return exec_err!("shuffle expects at least 1 argument");
110-
}
111-
if args.args.len() > 2 {
112-
return exec_err!("shuffle expects at most 2 arguments");
108+
if args.args.is_empty() || args.args.len() > 2 {
109+
return exec_err!("shuffle expects 1 or 2 argument(s)");
113110
}
114111

115112
// Extract seed from second argument if present
@@ -131,10 +128,10 @@ fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
131128
ColumnarValue::Scalar(scalar) => {
132129
let seed = match scalar {
133130
ScalarValue::Int64(Some(v)) => Some(*v as u64),
134-
ScalarValue::Null => None,
131+
ScalarValue::Null | ScalarValue::Int64(None) => None,
135132
_ => {
136133
return exec_err!(
137-
"shuffle seed must be Int64 type, got '{}'",
134+
"shuffle seed must be Int64 type but got '{}'",
138135
scalar.data_type()
139136
);
140137
}
@@ -164,7 +161,10 @@ fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> Result<ArrayR
164161
fixed_size_array_shuffle(array, field, seed)
165162
}
166163
Null => Ok(Arc::clone(input_array)),
167-
array_type => exec_err!("shuffle does not support type '{array_type}'."),
164+
array_type => exec_err!(
165+
"shuffle does not support type '{array_type}'; \
166+
expected types: List, LargeList, FixedSizeList or Null."
167+
),
168168
}
169169
}
170170

datafusion/sqllogictest/test_files/spark/array/shuffle.slt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@ SELECT shuffle([1, 2, 3, 4], CAST('2' AS INT));
107107
----
108108
[1, 4, 2, 3]
109109

110+
query ?
111+
SELECT shuffle(['ab'], NULL);
112+
----
113+
[ab]
114+
115+
query ?
116+
SELECT shuffle(shuffle([3, 3], NULL), NULL);
117+
----
118+
[3, 3]
119+
110120
# Clean up
111121
statement ok
112122
DROP TABLE test_shuffle_list_types;

0 commit comments

Comments
 (0)