Skip to content

Commit 03ef5fc

Browse files
committed
[mlir][linalg] Reject unsigned pooling on non-integer element types
1 parent 8785595 commit 03ef5fc

File tree

4 files changed

+76
-19
lines changed

4 files changed

+76
-19
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -579,13 +579,23 @@ class RegionBuilderHelper {
579579
return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
580580
case BinaryFn::max_unsigned:
581581
assert(!allComplex);
582-
if (allFloatingPoint)
583-
return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
582+
if (!allInteger || allBool) {
583+
if (emitError) {
584+
emitError() << "unsupported operation: unsigned max not on uint";
585+
return nullptr;
586+
}
587+
llvm_unreachable("unsupported operation: unsigned max not on uint");
588+
}
584589
return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
585590
case BinaryFn::min_unsigned:
586591
assert(!allComplex);
587-
if (allFloatingPoint)
588-
return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
592+
if (!allInteger || allBool) {
593+
if (emitError) {
594+
emitError() << "unsupported operation: unsigned min not on uint";
595+
return nullptr;
596+
}
597+
llvm_unreachable("unsupported operation: unsigned min not on uint");
598+
}
589599
return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
590600
case BinaryFn::powf:
591601
assert(allFloatingPoint);

mlir/test/Dialect/Linalg/named-ops-fail.mlir

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,20 @@ func.func @divu_broadcast(%arg0: memref<8x16xi32>, %arg1: memref<4x8x16xi32>, %a
8080

8181
// -----
8282

83+
func.func @pooling_nhwc_max_unsigned_float(
84+
%input: tensor<?x?x?x?xf32>,
85+
%filter: tensor<?x?xf32>,
86+
%init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
87+
// CHECK: unsupported operation: unsigned max not on uint
88+
linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
89+
strides = dense<1> : tensor<2xi64>}
90+
ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
91+
outs (%init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
92+
return %init_val : tensor<?x?x?x?xf32>
93+
}
94+
95+
// -----
96+
8397
func.func @exp_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
8498
// CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
8599
linalg.exp ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
@@ -349,4 +363,3 @@ func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<
349363
linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
350364
return
351365
}
352-

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,23 @@ func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
705705
return %res : tensor<1x2x2x1xf32>
706706
}
707707

708+
// -----
709+
710+
// CHECK-LABEL: func @pooling_nhwc_max_unsigned_tensor
711+
// CHECK: %{{.+}} = linalg.pooling_nhwc_max_unsigned
712+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
713+
// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
714+
func.func @pooling_nhwc_max_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
715+
%fake = tensor.empty() : tensor<3x3xi32>
716+
%init = tensor.empty() : tensor<1x2x2x1xi32>
717+
%cst = arith.constant 0 : i32
718+
%fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
719+
%res = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
720+
ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
721+
outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
722+
return %res : tensor<1x2x2x1xi32>
723+
}
724+
708725
// -----
709726
// CHECK-LABEL: func @pooling_nwc_max_tensor
710727
// CHECK: %{{.+}} = linalg.pooling_nwc_max
@@ -1017,6 +1034,23 @@ func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
10171034

10181035
// -----
10191036

1037+
// CHECK-LABEL: func @pooling_nhwc_min_unsigned_tensor
1038+
// CHECK: %{{.+}} = linalg.pooling_nhwc_min_unsigned
1039+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
1040+
// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
1041+
func.func @pooling_nhwc_min_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
1042+
%fake = tensor.empty() : tensor<3x3xi32>
1043+
%init = tensor.empty() : tensor<1x2x2x1xi32>
1044+
%cst = arith.constant 0 : i32
1045+
%fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
1046+
%res = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
1047+
ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
1048+
outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
1049+
return %res : tensor<1x2x2x1xi32>
1050+
}
1051+
1052+
// -----
1053+
10201054
// CHECK-LABEL: func @pooling_nwc_min_tensor
10211055
// CHECK: %{{.+}} = linalg.pooling_nwc_min
10221056
// CHECK-SAME: dilations = dense<1> : tensor<1xi64>

mlir/test/Dialect/Linalg/transform-op-decompose.mlir

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -131,21 +131,21 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
131131
}
132132

133133
// CHECK-LABEL: @pooling_nhwc_max_unsigned
134-
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
135-
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
136-
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
137-
func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
134+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
135+
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
136+
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
137+
func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
138138
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
139139
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
140140
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
141141
// CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
142142
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
143143
%0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
144144
strides = dense<1> : tensor<2xi64>}
145-
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
146-
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
145+
ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
146+
outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
147147
// CHECK: return %[[RES]]
148-
return %0 : tensor<?x1x?x?xf32>
148+
return %0 : tensor<?x1x?x?xi32>
149149
}
150150

151151
// CHECK-LABEL: @pooling_nhwc_min
@@ -167,21 +167,21 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
167167
}
168168

169169
// CHECK-LABEL: @pooling_nhwc_min_unsigned
170-
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
171-
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
172-
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
173-
func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
170+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
171+
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
172+
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
173+
func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
174174
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
175175
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
176176
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
177177
// CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
178178
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
179179
%0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
180180
strides = dense<1> : tensor<2xi64>}
181-
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
182-
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
181+
ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
182+
outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
183183
// CHECK: return %[[RES]]
184-
return %0 : tensor<?x1x?x?xf32>
184+
return %0 : tensor<?x1x?x?xi32>
185185
}
186186

187187
// CHECK-LABEL: @pooling_nchw_max

0 commit comments

Comments
 (0)