Skip to content

Commit ef01c31

Browse files
authored
fix(expr): unify the behavior on nullable for unnest() (#10797)
* fix(expr): unify the behavior on nullable for unnest() * fix * fix and_filters * fix * fix * fix
1 parent 68c1e85 commit ef01c31

File tree

8 files changed

+308
-489
lines changed

8 files changed

+308
-489
lines changed

src/query/expression/src/evaluator.rs

Lines changed: 101 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,10 @@ impl<'a> Evaluator<'a> {
157157
..
158158
} if function.signature.name == "if" => self.eval_if(args, generics, validity),
159159

160-
Expr::FunctionCall {
161-
function,
162-
args,
163-
generics,
164-
..
165-
} if function.signature.name == "and_filters" => {
166-
self.eval_and_filters(args, generics, validity)
160+
Expr::FunctionCall { function, args, .. }
161+
if function.signature.name == "and_filters" =>
162+
{
163+
self.eval_and_filters(args, validity)
167164
}
168165

169166
Expr::FunctionCall {
@@ -812,11 +809,10 @@ impl<'a> Evaluator<'a> {
812809
}
813810
}
814811

815-
// `and_filters` is a special builtin function similar to `if` that could partially
812+
// `and_filters` is a special builtin function similar to `if` that conditionally evaluates its arguments.
816813
fn eval_and_filters(
817814
&self,
818815
args: &[Expr],
819-
_: &[DataType],
820816
mut validity: Option<Bitmap>,
821817
) -> Result<Value<AnyType>> {
822818
assert!(args.len() >= 2);
@@ -850,7 +846,8 @@ impl<'a> Evaluator<'a> {
850846
}
851847
}
852848

853-
/// Evaluate a set returning function. Return multiple chunks of results, and the repeat times of each of the result.
849+
/// Evaluate a set-returning-function. Return multiple sets of results
850+
/// for each input row, along with the number of rows in each set.
854851
pub fn run_srf(&self, expr: &Expr) -> Result<Vec<(Value<AnyType>, usize)>> {
855852
if let Expr::FunctionCall {
856853
function,
@@ -866,7 +863,9 @@ impl<'a> Evaluator<'a> {
866863
.map(|expr| self.run(expr))
867864
.collect::<Result<Vec<_>>>()?;
868865
let cols_ref = args.iter().map(Value::as_ref).collect::<Vec<_>>();
869-
return Ok((eval)(&cols_ref, self.input_columns.num_rows()));
866+
let result = (eval)(&cols_ref, self.input_columns.num_rows());
867+
assert_eq!(result.len(), self.input_columns.num_rows());
868+
return Ok(result);
870869
}
871870
}
872871

@@ -1005,7 +1004,7 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
10051004
scalar,
10061005
data_type: dest_type.clone(),
10071006
},
1008-
new_domain,
1007+
None,
10091008
);
10101009
}
10111010
}
@@ -1032,68 +1031,109 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
10321031
return_type,
10331032
} if function.signature.name == "and_filters" => {
10341033
let mut args_expr = Vec::new();
1035-
let mut has_true = true;
1036-
let mut has_false = true;
1034+
let mut result_domain = Some(BooleanDomain {
1035+
has_true: true,
1036+
has_false: true,
1037+
});
10371038

10381039
type DomainType = NullableType<BooleanType>;
10391040
for arg in args {
10401041
let (expr, domain) = self.fold_once(arg);
1042+
// A temporary hack to make `and_filters` shortcut on false.
1043+
// TODO(andylokandy): make it a rule in the optimizer.
1044+
if let Expr::Constant {
1045+
scalar: Scalar::Boolean(false),
1046+
..
1047+
} = &expr
1048+
{
1049+
return (
1050+
Expr::Constant {
1051+
span: *span,
1052+
scalar: Scalar::Boolean(false),
1053+
data_type: DataType::Boolean,
1054+
},
1055+
None,
1056+
);
1057+
}
10411058
args_expr.push(expr);
10421059

1043-
match domain {
1044-
Some(domain) => {
1045-
let domain = DomainType::try_downcast_domain(&domain).unwrap();
1046-
let (domain_hash_true, domain_hash_false) = match &domain {
1047-
NullableDomain {
1048-
has_null,
1049-
value:
1050-
Some(box BooleanDomain {
1051-
has_true,
1052-
has_false,
1053-
}),
1054-
} => (*has_true, *has_null || *has_false),
1055-
NullableDomain { value: None, .. } => (false, true),
1056-
};
1057-
1058-
has_true = has_true && domain_hash_true;
1059-
has_false = has_false || domain_hash_false;
1060-
}
1061-
None => {
1062-
continue;
1060+
result_domain = result_domain.zip(domain).map(|(func_domain, domain)| {
1061+
let domain = DomainType::try_downcast_domain(&domain).unwrap();
1062+
let (domain_has_true, domain_has_false) = match &domain {
1063+
NullableDomain {
1064+
has_null,
1065+
value:
1066+
Some(box BooleanDomain {
1067+
has_true,
1068+
has_false,
1069+
}),
1070+
} => (*has_true, *has_null || *has_false),
1071+
NullableDomain { value: None, .. } => (false, true),
1072+
};
1073+
BooleanDomain {
1074+
has_true: func_domain.has_true && domain_has_true,
1075+
has_false: func_domain.has_false || domain_has_false,
10631076
}
1077+
});
1078+
1079+
if let Some(Scalar::Boolean(false)) = result_domain
1080+
.as_ref()
1081+
.and_then(|domain| Domain::Boolean(*domain).as_singleton())
1082+
{
1083+
return (
1084+
Expr::Constant {
1085+
span: *span,
1086+
scalar: Scalar::Boolean(false),
1087+
data_type: DataType::Boolean,
1088+
},
1089+
None,
1090+
);
10641091
}
10651092
}
10661093

1067-
if !has_true && has_false {
1068-
(
1094+
if let Some(scalar) = result_domain
1095+
.as_ref()
1096+
.and_then(|domain| Domain::Boolean(*domain).as_singleton())
1097+
{
1098+
return (
10691099
Expr::Constant {
10701100
span: *span,
1071-
scalar: Scalar::Boolean(false),
1101+
scalar,
10721102
data_type: DataType::Boolean,
10731103
},
1074-
Some(Domain::Boolean(BooleanDomain {
1075-
has_true: false,
1076-
has_false: true,
1077-
})),
1078-
)
1079-
} else {
1080-
let func_expr = Expr::FunctionCall {
1081-
span: *span,
1082-
id: id.clone(),
1083-
function: function.clone(),
1084-
generics: generics.clone(),
1085-
args: args_expr,
1086-
return_type: return_type.clone(),
1087-
};
1088-
1089-
(
1090-
func_expr,
1091-
Some(Domain::Boolean(BooleanDomain {
1092-
has_true,
1093-
has_false,
1094-
})),
1095-
)
1104+
None,
1105+
);
10961106
}
1107+
1108+
let all_args_is_scalar = args_expr.iter().all(|arg| arg.as_constant().is_some());
1109+
1110+
let func_expr = Expr::FunctionCall {
1111+
span: *span,
1112+
id: id.clone(),
1113+
function: function.clone(),
1114+
generics: generics.clone(),
1115+
args: args_expr,
1116+
return_type: return_type.clone(),
1117+
};
1118+
1119+
if all_args_is_scalar {
1120+
let block = DataBlock::empty();
1121+
let evaluator = Evaluator::new(&block, self.func_ctx, self.fn_registry);
1122+
// Since we know the expression is constant, it'll be safe to change its column index type.
1123+
let func_expr = func_expr.project_column_ref(|_| unreachable!());
1124+
if let Ok(Value::Scalar(scalar)) = evaluator.run(&func_expr) {
1125+
return (
1126+
Expr::Constant {
1127+
span: *span,
1128+
scalar,
1129+
data_type: return_type.clone(),
1130+
},
1131+
None,
1132+
);
1133+
}
1134+
}
1135+
1136+
(func_expr, result_domain.map(Domain::Boolean))
10971137
}
10981138
Expr::FunctionCall {
10991139
span,
@@ -1143,7 +1183,7 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
11431183
scalar,
11441184
data_type: return_type.clone(),
11451185
},
1146-
func_domain,
1186+
None,
11471187
);
11481188
}
11491189

@@ -1159,7 +1199,7 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {
11591199
scalar,
11601200
data_type: return_type.clone(),
11611201
},
1162-
func_domain,
1202+
None,
11631203
);
11641204
}
11651205
}

src/query/expression/src/function.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ pub type AutoCastRules<'a> = &'a [(DataType, DataType)];
4747
/// A function to build function depending on the const parameters and the type of arguments (before coercion).
4848
///
4949
/// The first argument is the const parameters and the second argument is the types of arguments.
50-
pub type FunctionFactory =
51-
Box<dyn Fn(&[usize], &[DataType]) -> Option<Arc<Function>> + Send + Sync + 'static>;
50+
pub trait FunctionFactory =
51+
Fn(&[usize], &[DataType]) -> Option<Arc<Function>> + Send + Sync + 'static;
5252

5353
pub struct Function {
5454
pub signature: FunctionSignature,
@@ -73,9 +73,10 @@ pub enum FunctionEval {
7373
/// The result must have the same length as the input arguments if it's a column.
7474
eval: Box<dyn Fn(&[ValueRef<AnyType>], &mut EvalContext) -> Value<AnyType> + Send + Sync>,
7575
},
76-
/// Set returning function that returns a series of values.
76+
/// Set-returning function that takes a scalar as input and returns a set.
7777
SRF {
78-
/// Given a set of arguments, return a series of chunks of result and the repeat time of each chunk.
78+
/// Given multiple rows, return multiple sets of results
79+
/// for each input row, along with the number of rows in each set.
7980
eval:
8081
Box<dyn Fn(&[ValueRef<AnyType>], usize) -> Vec<(Value<AnyType>, usize)> + Send + Sync>,
8182
},
@@ -118,7 +119,8 @@ pub enum FunctionID {
118119
#[derive(Default)]
119120
pub struct FunctionRegistry {
120121
pub funcs: HashMap<String, Vec<(Arc<Function>, usize)>>,
121-
pub factories: HashMap<String, Vec<(FunctionFactory, usize)>>,
122+
#[allow(clippy::type_complexity)]
123+
pub factories: HashMap<String, Vec<(Box<dyn FunctionFactory>, usize)>>,
122124

123125
/// Aliases map from alias function name to original function name.
124126
pub aliases: HashMap<String, String>,
@@ -288,11 +290,7 @@ impl FunctionRegistry {
288290
.push((Arc::new(func), id));
289291
}
290292

291-
pub fn register_function_factory(
292-
&mut self,
293-
name: &str,
294-
factory: impl Fn(&[usize], &[DataType]) -> Option<Arc<Function>> + 'static + Send + Sync,
295-
) {
293+
pub fn register_function_factory(&mut self, name: &str, factory: impl FunctionFactory) {
296294
let id = self.next_function_id(name);
297295
self.factories
298296
.entry(name.to_string())

src/query/expression/src/type_check.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,13 @@ pub fn try_unify_signature(
361361
dest_tys: impl IntoIterator<Item = &DataType> + ExactSizeIterator,
362362
auto_cast_rules: AutoCastRules,
363363
) -> Result<Substitution> {
364-
assert_eq!(src_tys.len(), dest_tys.len());
364+
if src_tys.len() != dest_tys.len() {
365+
return Err(ErrorCode::from_string_no_backtrace(format!(
366+
"expected {} arguments, got {}",
367+
dest_tys.len(),
368+
src_tys.len()
369+
)));
370+
}
365371

366372
let substs = src_tys
367373
.into_iter()

src/query/expression/src/values.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -885,14 +885,6 @@ impl Column {
885885
}
886886
}
887887

888-
/// Unnest a nested column into one column.
889-
pub fn unnest(&self) -> Self {
890-
match self {
891-
Column::Array(array) => array.underlying_column().unnest(),
892-
col => col.clone(),
893-
}
894-
}
895-
896888
pub fn arrow_field(&self) -> common_arrow::arrow::datatypes::Field {
897889
use common_arrow::arrow::datatypes::DataType as ArrowDataType;
898890
use common_arrow::arrow::datatypes::Field as ArrowField;

0 commit comments

Comments
 (0)