Skip to content

Commit 99daafd

Browse files
authored
feat(spark): implement Spark conditional function if (apache#16946)
1 parent 25acb64 commit 99daafd

File tree

3 files changed

+255
-6
lines changed
  • datafusion
    • spark/src/function/conditional
    • sqllogictest/test_files/spark/conditional

3 files changed

+255
-6
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::datatypes::DataType;
19+
use datafusion_common::{internal_err, plan_err, Result};
20+
use datafusion_expr::{
21+
binary::try_type_union_resolution, simplify::ExprSimplifyResult, when, ColumnarValue,
22+
Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
23+
};
24+
25+
#[derive(Debug, PartialEq, Eq, Hash)]
26+
pub struct SparkIf {
27+
signature: Signature,
28+
}
29+
30+
impl Default for SparkIf {
31+
fn default() -> Self {
32+
Self::new()
33+
}
34+
}
35+
36+
impl SparkIf {
37+
pub fn new() -> Self {
38+
Self {
39+
signature: Signature::user_defined(Volatility::Immutable),
40+
}
41+
}
42+
}
43+
44+
impl ScalarUDFImpl for SparkIf {
45+
fn as_any(&self) -> &dyn std::any::Any {
46+
self
47+
}
48+
49+
fn name(&self) -> &str {
50+
"if"
51+
}
52+
53+
fn signature(&self) -> &Signature {
54+
&self.signature
55+
}
56+
57+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
58+
if arg_types.len() != 3 {
59+
return plan_err!(
60+
"Function 'if' expects 3 arguments but received {}",
61+
arg_types.len()
62+
);
63+
}
64+
65+
if arg_types[0] != DataType::Boolean && arg_types[0] != DataType::Null {
66+
return plan_err!(
67+
"For function 'if' {} is not a boolean or null",
68+
arg_types[0]
69+
);
70+
}
71+
72+
let target_types = try_type_union_resolution(&arg_types[1..])?;
73+
let mut result = vec![DataType::Boolean];
74+
result.extend(target_types);
75+
Ok(result)
76+
}
77+
78+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
79+
Ok(arg_types[1].clone())
80+
}
81+
82+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
83+
internal_err!("if should have been simplified to case")
84+
}
85+
86+
fn simplify(
87+
&self,
88+
args: Vec<Expr>,
89+
_info: &dyn datafusion_expr::simplify::SimplifyInfo,
90+
) -> Result<ExprSimplifyResult> {
91+
let condition = args[0].clone();
92+
let then_expr = args[1].clone();
93+
let else_expr = args[2].clone();
94+
95+
// Convert IF(condition, then_expr, else_expr) to
96+
// CASE WHEN condition THEN then_expr ELSE else_expr END
97+
let case_expr = when(condition, then_expr).otherwise(else_expr)?;
98+
99+
Ok(ExprSimplifyResult::Simplified(case_expr))
100+
}
101+
}

datafusion/spark/src/function/conditional/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,19 @@
1616
// under the License.
1717

1818
use datafusion_expr::ScalarUDF;
19+
use datafusion_functions::make_udf_function;
1920
use std::sync::Arc;
2021

21-
pub mod expr_fn {}
22+
mod r#if;
23+
24+
make_udf_function!(r#if::SparkIf, r#if);
25+
26+
pub mod expr_fn {
27+
use datafusion_functions::export_functions;
28+
29+
export_functions!((r#if, "If arg1 evaluates to true, then returns arg2; otherwise returns arg3", arg1 arg2 arg3));
30+
}
2231

2332
pub fn functions() -> Vec<Arc<ScalarUDF>> {
24-
vec![]
33+
vec![r#if()]
2534
}

datafusion/sqllogictest/test_files/spark/conditional/if.slt

Lines changed: 143 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,146 @@
2121
# For more information, please see:
2222
# https://github.com/apache/datafusion/issues/15914
2323

24-
## Original Query: SELECT if(1 < 2, 'a', 'b');
25-
## PySpark 3.5.5 Result: {'(IF((1 < 2), a, b))': 'a', 'typeof((IF((1 < 2), a, b)))': 'string', 'typeof((1 < 2))': 'boolean', 'typeof(a)': 'string', 'typeof(b)': 'string'}
26-
#query
27-
#SELECT if((1 < 2)::boolean, 'a'::string, 'b'::string);
24+
## Basic IF function tests
25+
26+
# Test basic true condition
27+
query T
28+
SELECT if(true, 'yes', 'no');
29+
----
30+
yes
31+
32+
# Test basic false condition
33+
query T
34+
SELECT if(false, 'yes', 'no');
35+
----
36+
no
37+
38+
# Test with comparison operators
39+
query T
40+
SELECT if(1 < 2, 'a', 'b');
41+
----
42+
a
43+
44+
query T
45+
SELECT if(1 > 2, 'a', 'b');
46+
----
47+
b
48+
49+
50+
## Numeric type tests
51+
52+
# Test with integers
53+
query I
54+
SELECT if(true, 10, 20);
55+
----
56+
10
57+
58+
query I
59+
SELECT if(false, 10, 20);
60+
----
61+
20
62+
63+
# Test with larger integer values
64+
query I
65+
SELECT if(true, 100, 200);
66+
----
67+
100
68+
69+
## Float type tests
70+
71+
# Test with floating point numbers
72+
query R
73+
SELECT if(true, 1.5, 2.5);
74+
----
75+
1.5
76+
77+
query R
78+
SELECT if(false, 1.5, 2.5);
79+
----
80+
2.5
81+
82+
## String type tests
83+
84+
# Test with different string values
85+
query T
86+
SELECT if(true, 'hello', 'world');
87+
----
88+
hello
89+
90+
query T
91+
SELECT if(false, 'hello', 'world');
92+
----
93+
world
94+
95+
## NULL handling tests
96+
97+
# Test with NULL condition
98+
query T
99+
SELECT if(NULL, 'yes', 'no');
100+
----
101+
no
102+
103+
query T
104+
SELECT if(NOT NULL, 'yes', 'no');
105+
----
106+
no
107+
108+
# Test with NULL true value
109+
query T
110+
SELECT if(true, NULL, 'no');
111+
----
112+
NULL
113+
114+
# Test with NULL false value
115+
query T
116+
SELECT if(false, 'yes', NULL);
117+
----
118+
NULL
119+
120+
# Test with all NULL
121+
query ?
122+
SELECT if(true, NULL, NULL);
123+
----
124+
NULL
125+
126+
## Type coercion tests
127+
128+
# Test integer to float coercion
129+
query R
130+
SELECT if(true, 10, 20.5);
131+
----
132+
10
133+
134+
query R
135+
SELECT if(false, 10, 20.5);
136+
----
137+
20.5
138+
139+
# Test integer to float coercion with the float in the THEN branch
140+
query R
141+
SELECT if(true, 10.5, 20);
142+
----
143+
10.5
144+
145+
query R
146+
SELECT if(false, 10.5, 20);
147+
----
148+
20
149+
150+
statement error Int64 is not a boolean or null
151+
SELECT if(1, 10.5, 20);
152+
153+
154+
statement error Utf8 is not a boolean or null
155+
SELECT if('x', 10.5, 20);
156+
157+
query II
158+
SELECT v, IF(v < 0, 10/0, 1) FROM (VALUES (1), (2)) t(v)
159+
----
160+
1 1
161+
2 1
162+
163+
query I
164+
SELECT IF(true, 1 / 1, 1 / 0);
165+
----
166+
1

0 commit comments

Comments
 (0)