Skip to content

Commit 937894f

Browse files
authored
fix(query): Prevent invalid filter expression generation in InferFilterOptimizer (#17929)
fix(query): fix infer filter transitive equality with different column types
1 parent fc3b25e commit 937894f

File tree

4 files changed

+170
-9
lines changed

4 files changed

+170
-9
lines changed

src/query/service/tests/it/sql/planner/optimizer/optimizers/operator/filter/infer_filter_test.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,7 @@ fn test_different_data_types() -> Result<()> {
970970
"",
971971
0,
972972
);
973+
let col_variant = builder.column("variant", 3, "variant", DataType::Variant, "", 0);
973974
let _col_a = builder.column("A", 3, "A", DataType::Number(NumberDataType::Int64), "", 0);
974975

975976
// Test: int8 column with values at type boundaries
@@ -1023,6 +1024,19 @@ fn test_different_data_types() -> Result<()> {
10231024
);
10241025
}
10251026

1027+
// Test: mixing integer and variant types
1028+
{
1029+
let const_5_int = builder.int(5);
1030+
1031+
// Test: variant = int8 AND variant = 5
1032+
let pred_variant_eq_int = builder.eq(col_variant.clone(), col_int8.clone());
1033+
let pred_variant_eq_5 = builder.eq(col_variant.clone(), const_5_int.clone());
1034+
1035+
let result = run_optimizer(vec![pred_variant_eq_int, pred_variant_eq_5])?;
1036+
1037+
assert_eq!(result.len(), 2, "Shouldn't add transitive equality");
1038+
}
1039+
10261040
// Different data type not work yet, need fix.
10271041
// Test: mixing integer and float types
10281042
// {

src/query/sql/src/planner/optimizer/optimizers/operator/filter/infer_filter.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ impl<'a> InferFilterOptimizer<'a> {
9898
) {
9999
(true, true) => {
100100
if op == ComparisonOp::Equal {
101-
self.add_equal_expr(&func.arguments[0], &func.arguments[1]);
101+
if !self.add_equal_expr(&func.arguments[0], &func.arguments[1]) {
102+
remaining_predicates.push(predicate);
103+
}
102104
} else {
103105
remaining_predicates.push(predicate);
104106
}
@@ -178,7 +180,16 @@ impl<'a> InferFilterOptimizer<'a> {
178180
self.expr_equal_to.push(expr_equal_to);
179181
}
180182

181-
pub fn add_equal_expr(&mut self, left: &ScalarExpr, right: &ScalarExpr) {
183+
pub fn add_equal_expr(&mut self, left: &ScalarExpr, right: &ScalarExpr) -> bool {
184+
let Ok(left_ty) = left.data_type() else {
185+
return false;
186+
};
187+
let Ok(right_ty) = right.data_type() else {
188+
return false;
189+
};
190+
if !Self::check_equal_expr_type(&left_ty, &right_ty) {
191+
return false;
192+
}
182193
match self.expr_index.get(left) {
183194
Some(index) => self.expr_equal_to[*index].push(right.clone()),
184195
None => self.add_expr(left, vec![], vec![right.clone()]),
@@ -188,6 +199,28 @@ impl<'a> InferFilterOptimizer<'a> {
188199
Some(index) => self.expr_equal_to[*index].push(left.clone()),
189200
None => self.add_expr(right, vec![], vec![left.clone()]),
190201
};
202+
203+
true
204+
}
205+
206+
// equal expr must have the same type, otherwise the function may fail on execution.
207+
fn check_equal_expr_type(left_ty: &DataType, right_ty: &DataType) -> bool {
208+
match (left_ty.remove_nullable(), right_ty.remove_nullable()) {
209+
(DataType::Number(l), DataType::Number(r)) => {
210+
(l.is_integer() && r.is_integer()) || (l.is_float() && r.is_float())
211+
}
212+
(DataType::Decimal(_), DataType::Decimal(_)) => true,
213+
(DataType::Array(box l), DataType::Array(box r)) => Self::check_equal_expr_type(&l, &r),
214+
(DataType::Map(box l), DataType::Map(box r)) => Self::check_equal_expr_type(&l, &r),
215+
(DataType::Tuple(l_tys), DataType::Tuple(r_tys)) => {
216+
l_tys.len() == r_tys.len()
217+
&& l_tys
218+
.iter()
219+
.zip(r_tys.iter())
220+
.all(|(l_ty, r_ty)| Self::check_equal_expr_type(l_ty, r_ty))
221+
}
222+
(_, _) => left_ty.eq(right_ty),
223+
}
191224
}
192225

193226
fn add_expr_predicate(&mut self, expr: &ScalarExpr, new_predicate: Predicate) -> Result<()> {

tests/sqllogictests/suites/mode/standalone/explain/infer_filter.test

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,20 @@ drop table if exists t2;
77
statement ok
88
drop table if exists t3;
99

10+
statement ok
11+
drop table if exists t4;
12+
1013
statement ok
1114
create table t1(a int not null, b int not null);
1215

1316
statement ok
1417
create table t2(a int not null, b int not null);
1518

1619
statement ok
17-
create table t3(a int not null, b int not null);
20+
create table t3(a int not null, b int null);
21+
22+
statement ok
23+
create table t4(a int not null, b variant null);
1824

1925
# a = 1
2026
query T
@@ -753,9 +759,66 @@ HashJoin
753759
├── apply join filters: [#1]
754760
└── estimated rows: 0.00
755761

762+
query T
763+
explain select * from t4 where a = b and strip_null_value(b) is not null;
764+
----
765+
Filter
766+
├── output columns: [t4.a (#0), t4.b (#1)]
767+
├── filters: [is_true(CAST(t4.a (#0) AS Int32 NULL) = TRY_CAST(t4.b (#1) AS Int32 NULL)), is_not_null(strip_null_value(t4.b (#1)))]
768+
├── estimated rows: 0.00
769+
└── TableScan
770+
├── table: default.default.t4
771+
├── output columns: [a (#0), b (#1)]
772+
├── read rows: 0
773+
├── read size: 0
774+
├── partitions total: 0
775+
├── partitions scanned: 0
776+
├── push downs: [filters: [and_filters(CAST(t4.a (#0) AS Int32 NULL) = TRY_CAST(t4.b (#1) AS Int32 NULL), is_not_null(strip_null_value(t4.b (#1))))], limit: NONE]
777+
└── estimated rows: 0.00
778+
779+
query T
780+
explain select * from t3 join t4 on t3.b = t4.b where strip_null_value(t4.b) is not null;
781+
----
782+
HashJoin
783+
├── output columns: [t3.a (#0), t3.b (#1), t4.a (#2), t4.b (#3)]
784+
├── join type: INNER
785+
├── build keys: [CAST(t4.b (#3) AS Int32 NULL)]
786+
├── probe keys: [t3.b (#1)]
787+
├── keys is null equal: [false]
788+
├── filters: []
789+
├── build join filters:
790+
│ └── filter id:0, build key:CAST(t4.b (#3) AS Int32 NULL), probe key:t3.b (#1), filter type:bloom,inlist,min_max
791+
├── estimated rows: 0.00
792+
├── Filter(Build)
793+
│ ├── output columns: [t4.a (#2), t4.b (#3)]
794+
│ ├── filters: [is_not_null(strip_null_value(t4.b (#3)))]
795+
│ ├── estimated rows: 0.00
796+
│ └── TableScan
797+
│ ├── table: default.default.t4
798+
│ ├── output columns: [a (#2), b (#3)]
799+
│ ├── read rows: 0
800+
│ ├── read size: 0
801+
│ ├── partitions total: 0
802+
│ ├── partitions scanned: 0
803+
│ ├── push downs: [filters: [is_not_null(strip_null_value(t4.b (#3)))], limit: NONE]
804+
│ └── estimated rows: 0.00
805+
└── TableScan(Probe)
806+
├── table: default.default.t3
807+
├── output columns: [a (#0), b (#1)]
808+
├── read rows: 0
809+
├── read size: 0
810+
├── partitions total: 0
811+
├── partitions scanned: 0
812+
├── push downs: [filters: [], limit: NONE]
813+
├── apply join filters: [#0]
814+
└── estimated rows: 0.00
815+
756816
statement ok
757817
drop table if exists t3;
758818

819+
statement ok
820+
drop table if exists t4;
821+
759822
# merge predicates with different data types.
760823
statement ok
761824
create or replace table t1(id BIGINT NOT NULL);
@@ -815,3 +878,4 @@ drop table if exists t1;
815878

816879
statement ok
817880
drop table if exists t2;
881+

tests/sqllogictests/suites/mode/standalone/explain_native/infer_filter.test

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,20 @@ drop table if exists t2;
77
statement ok
88
drop table if exists t3;
99

10+
statement ok
11+
drop table if exists t4;
12+
1013
statement ok
1114
create table t1(a int not null, b int not null);
1215

1316
statement ok
1417
create table t2(a int not null, b int not null);
1518

1619
statement ok
17-
create table t3(a int not null, b int not null);
20+
create table t3(a int not null, b int null);
21+
22+
statement ok
23+
create table t4(a int not null, b variant null);
1824

1925
# a = 1
2026
query T
@@ -621,6 +627,52 @@ HashJoin
621627
├── apply join filters: [#1]
622628
└── estimated rows: 0.00
623629

630+
query T
631+
explain select * from t4 where a = b and strip_null_value(b) is not null;
632+
----
633+
TableScan
634+
├── table: default.default.t4
635+
├── output columns: [a (#0), b (#1)]
636+
├── read rows: 0
637+
├── read size: 0
638+
├── partitions total: 0
639+
├── partitions scanned: 0
640+
├── push downs: [filters: [and_filters(CAST(t4.a (#0) AS Int32 NULL) = TRY_CAST(t4.b (#1) AS Int32 NULL), is_not_null(strip_null_value(t4.b (#1))))], limit: NONE]
641+
└── estimated rows: 0.00
642+
643+
query T
644+
explain select * from t3 join t4 on t3.b = t4.b where strip_null_value(t4.b) is not null;
645+
----
646+
HashJoin
647+
├── output columns: [t3.a (#0), t3.b (#1), t4.a (#2), t4.b (#3)]
648+
├── join type: INNER
649+
├── build keys: [CAST(t4.b (#3) AS Int32 NULL)]
650+
├── probe keys: [t3.b (#1)]
651+
├── keys is null equal: [false]
652+
├── filters: []
653+
├── build join filters:
654+
│ └── filter id:0, build key:CAST(t4.b (#3) AS Int32 NULL), probe key:t3.b (#1), filter type:bloom,inlist,min_max
655+
├── estimated rows: 0.00
656+
├── TableScan(Build)
657+
│ ├── table: default.default.t4
658+
│ ├── output columns: [a (#2), b (#3)]
659+
│ ├── read rows: 0
660+
│ ├── read size: 0
661+
│ ├── partitions total: 0
662+
│ ├── partitions scanned: 0
663+
│ ├── push downs: [filters: [is_not_null(strip_null_value(t4.b (#3)))], limit: NONE]
664+
│ └── estimated rows: 0.00
665+
└── TableScan(Probe)
666+
├── table: default.default.t3
667+
├── output columns: [a (#0), b (#1)]
668+
├── read rows: 0
669+
├── read size: 0
670+
├── partitions total: 0
671+
├── partitions scanned: 0
672+
├── push downs: [filters: [], limit: NONE]
673+
├── apply join filters: [#0]
674+
└── estimated rows: 0.00
675+
624676
statement ok
625677
drop table if exists t1;
626678

@@ -630,13 +682,10 @@ drop table if exists t2;
630682
statement ok
631683
drop table if exists t3;
632684

633-
# merge predicates with different data types.
634685
statement ok
635-
drop table if exists t1;
636-
637-
statement ok
638-
drop table if exists t2;
686+
drop table if exists t4;
639687

688+
# merge predicates with different data types.
640689
statement ok
641690
create table t1(id BIGINT NOT NULL);
642691

@@ -687,3 +736,4 @@ drop table if exists t1;
687736

688737
statement ok
689738
drop table if exists t2;
739+

0 commit comments

Comments
 (0)