diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 254f54d9e459e..f2f23954d5c19 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2314,7 +2314,8 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { return trueVal; // select %x, true, false => %x - if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) && + if (getType().isSignlessInteger(1) && + matchPattern(adaptor.getTrueValue(), m_One()) && matchPattern(adaptor.getFalseValue(), m_Zero())) return condition; diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index a386a178b7899..f9997ec2796af 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -54,6 +54,18 @@ func.func @select_extui_i1(%arg0: i1) -> i1 { return %res : i1 } +// CHECK-LABEL: @select_no_fold_ui1 +// CHECK: %[[CONST_0:.+]] = "test.constant"() <{value = 0 : i32}> : () -> ui1 +// CHECK: %[[CONST_1:.+]] = "test.constant"() <{value = 1 : i32}> : () -> ui1 +// CHECK-NEXT: %[[RES:.+]] = arith.select %arg0, %[[CONST_1]], %[[CONST_0]] : ui1 +// CHECK-NEXT: return %[[RES]] +func.func @select_no_fold_ui1(%arg0: i1) -> ui1 { + %c0_i1 = "test.constant"() {value = 0 : i32} : () -> ui1 + %c1_i1 = "test.constant"() {value = 1 : i32} : () -> ui1 + %res = arith.select %arg0, %c1_i1, %c0_i1 : ui1 + return %res : ui1 +} + // CHECK-LABEL: @select_cst_false_scalar // CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) // CHECK-NEXT: return %[[ARG1]]