Skip to content

Commit 701b445

Browse files
authored
fix(query): fix register function working with nullable scalar (#17217)
* fix(query): fix register function working with nullable scalar * fix(query): fix register function working with nullable scalar * fix(query): increase pool * Update 19_0005_fuzz_cte.sh * Update mysql_source.rs * fix(query): fix register function working with nullable scalar
1 parent 41856e3 commit 701b445

File tree

5 files changed

+135
-17
lines changed

5 files changed

+135
-17
lines changed

src/query/expression/src/register_vectorize.rs

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,7 @@ pub fn passthrough_nullable_1_arg<I1: ArgType, O: ArgType>(
283283

284284
match out {
285285
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
286-
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
287-
_ => Value::Scalar(None),
286+
Value::Scalar(out) => Value::Scalar(Some(out)),
288287
}
289288
}
290289
_ => Value::Scalar(None),
@@ -308,15 +307,15 @@ pub fn passthrough_nullable_2_arg<I1: ArgType, I2: ArgType, O: ArgType>(
308307
if let Some(validity) = ctx.validity.as_ref() {
309308
args_validity = &args_validity & validity;
310309
}
310+
311311
ctx.validity = Some(args_validity.clone());
312312
match (arg1.value(), arg2.value()) {
313313
(Some(arg1), Some(arg2)) => {
314314
let out = func(arg1, arg2, ctx);
315315

316316
match out {
317317
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
318-
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
319-
_ => Value::Scalar(None),
318+
Value::Scalar(out) => Value::Scalar(Some(out)),
320319
}
321320
}
322321
_ => Value::Scalar(None),
@@ -352,8 +351,7 @@ pub fn passthrough_nullable_3_arg<I1: ArgType, I2: ArgType, I3: ArgType, O: ArgT
352351

353352
match out {
354353
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
355-
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
356-
_ => Value::Scalar(None),
354+
Value::Scalar(out) => Value::Scalar(Some(out)),
357355
}
358356
}
359357
_ => Value::Scalar(None),
@@ -397,8 +395,7 @@ pub fn passthrough_nullable_4_arg<
397395

398396
match out {
399397
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
400-
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
401-
_ => Value::Scalar(None),
398+
Value::Scalar(out) => Value::Scalar(Some(out)),
402399
}
403400
}
404401
_ => Value::Scalar(None),
@@ -427,8 +424,7 @@ pub fn combine_nullable_1_arg<I1: ArgType, O: ArgType>(
427424
out.column,
428425
&args_validity & &out.validity,
429426
)),
430-
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
431-
_ => Value::Scalar(None),
427+
Value::Scalar(out) => Value::Scalar(out),
432428
}
433429
}
434430
_ => Value::Scalar(None),
@@ -465,8 +461,7 @@ pub fn combine_nullable_2_arg<I1: ArgType, I2: ArgType, O: ArgType>(
465461
out.column,
466462
&args_validity & &out.validity,
467463
)),
468-
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
469-
_ => Value::Scalar(None),
464+
Value::Scalar(out) => Value::Scalar(out),
470465
}
471466
}
472467
_ => Value::Scalar(None),
@@ -505,8 +500,7 @@ pub fn combine_nullable_3_arg<I1: ArgType, I2: ArgType, I3: ArgType, O: ArgType>
505500
out.column,
506501
&args_validity & &out.validity,
507502
)),
508-
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
509-
_ => Value::Scalar(None),
503+
Value::Scalar(out) => Value::Scalar(out),
510504
}
511505
}
512506
_ => Value::Scalar(None),
@@ -552,8 +546,7 @@ pub fn combine_nullable_4_arg<I1: ArgType, I2: ArgType, I3: ArgType, I4: ArgType
552546
out.column,
553547
&args_validity & &out.validity,
554548
)),
555-
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
556-
_ => Value::Scalar(None),
549+
Value::Scalar(out) => Value::Scalar(out),
557550
}
558551
}
559552
_ => Value::Scalar(None),

src/query/functions/tests/it/scalars/mod.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,43 @@ fn list_all_builtin_functions() {
271271
fn check_ambiguity() {
272272
BUILTIN_FUNCTIONS.check_ambiguity()
273273
}
274+
275+
#[test]
276+
fn test_if_function() -> Result<()> {
277+
use databend_common_expression::types::*;
278+
use databend_common_expression::FromData;
279+
use databend_common_expression::Scalar;
280+
let raw_expr = parser::parse_raw_expr("if(eq(n,1), sum_sid + 1,100)", &[
281+
("n", UInt8Type::data_type()),
282+
("sum_sid", Int32Type::data_type().wrap_nullable()),
283+
]);
284+
let expr = type_check::check(&raw_expr, &BUILTIN_FUNCTIONS)?;
285+
let block = DataBlock::new(
286+
vec![
287+
BlockEntry {
288+
data_type: UInt8Type::data_type(),
289+
value: Value::Column(UInt8Type::from_data(vec![2_u8, 1])),
290+
},
291+
BlockEntry {
292+
data_type: Int32Type::data_type().wrap_nullable(),
293+
value: Value::Scalar(Scalar::Number(NumberScalar::Int32(2400_i32))),
294+
},
295+
],
296+
2,
297+
);
298+
let func_ctx = FunctionContext::default();
299+
let evaluator = Evaluator::new(&block, &func_ctx, &BUILTIN_FUNCTIONS);
300+
let result = evaluator.run(&expr).unwrap();
301+
let result = result
302+
.as_column()
303+
.unwrap()
304+
.clone()
305+
.as_nullable()
306+
.unwrap()
307+
.clone();
308+
309+
let bm = Bitmap::from_iter([true, true]);
310+
assert_eq!(result.validity, bm);
311+
assert_eq!(result.column, Int64Type::from_data(vec![100, 2401]));
312+
Ok(())
313+
}

tests/sqllogictests/suites/query/cte/basic_r_cte.test

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,5 +227,64 @@ select cte1.a from cte1;
227227
8
228228
9
229229

230+
231+
statement ok
232+
create table train(
233+
train_id varchar(8) not null ,
234+
departure_station varchar(32) not null,
235+
arrival_station varchar(32) not null,
236+
seat_count int not null
237+
);
238+
239+
statement ok
240+
create table passenger(
241+
passenger_id varchar(16) not null,
242+
departure_station varchar(32) not null,
243+
arrival_station varchar(32) not null
244+
);
245+
246+
statement ok
247+
create table city(city varchar(32));
248+
249+
statement ok
250+
insert into city
251+
with t as (select 1 n union select 2 union select 3 union select 4 union select 5)
252+
,t1 as(select row_number()over() rn from t ,t t2,t t3)
253+
select concat('城市',rn::varchar) city from t1 where rn<=5;
254+
255+
statement ok
256+
insert into train
257+
select concat('G',row_number()over()::varchar),c1.city,c2.city, n from city c1, city c2, (select 600 n union select 800 union select 1200 union select 1600) a ;
258+
259+
statement ok
260+
insert into passenger
261+
select concat('P',substr((100000000+row_number()over())::varchar,2)),c1.city,c2.city from city c1, city c2 ,city c3, city c4, city c5,
262+
city c6, (select 1 n union select 2 union select 3 union select 4) c7,(select 1 n union select 2) c8;
263+
264+
265+
query III
266+
with
267+
t0 as (
268+
select
269+
train_id,
270+
seat_count,
271+
sum(seat_count) over (
272+
partition by departure_station, arrival_station order by train_id
273+
) ::int sum_sid
274+
from
275+
train
276+
)
277+
select
278+
sum(case when n=1 then sum_sid+1 else 0 end::int),
279+
sum(sum_sid),
280+
sum(seat_count)
281+
from
282+
t0,(select 1 n union all select 2);
283+
----
284+
261700 523200 210000
285+
286+
statement ok
287+
use default;
288+
230289
statement ok
231-
drop table t1;
290+
drop database db;
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
OK
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/usr/bin/env bash
2+
3+
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
4+
. "$CURDIR"/../../../shell_env.sh
5+
6+
7+
times=256
8+
9+
echo "" > /tmp/fuzz_a.txt
10+
echo "" > /tmp/fuzz_b.txt
11+
12+
for i in `seq 1 ${times}`;do
13+
echo """with t0(sum_sid) as (select sum(number) over(partition by number order by number)
14+
from numbers(3)) select n, if(n =1, sum_sid +1, 0) from t0, (select 1 n union all select 2) order by 1,2;
15+
""" | $BENDSQL_CLIENT_CONNECT >> /tmp/fuzz_a.txt
16+
done
17+
18+
19+
for i in `seq 1 ${times}`;do
20+
echo """with t0(sum_sid) as (select sum(number) over(partition by number order by number)
21+
from numbers(3)) select n, if(n =1, sum_sid +1, 0) from t0, (select 1 n union all select 2) order by 1,2;
22+
""" | $BENDSQL_CLIENT_CONNECT >> /tmp/fuzz_b.txt
23+
done
24+
25+
diff /tmp/fuzz_a.txt /tmp/fuzz_b.txt && echo "OK"

0 commit comments

Comments
 (0)