diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 3dc45edf4a23f..8eb03dc182ae9 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -579,13 +579,23 @@ class RegionBuilderHelper { return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::max_unsigned: assert(!allComplex); - if (allFloatingPoint) - return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1); + if (!allInteger || allBool) { + if (emitError) { + emitError() << "unsupported operation: unsigned max not on uint"; + return nullptr; + } + llvm_unreachable("unsupported operation: unsigned max not on uint"); + } return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::min_unsigned: assert(!allComplex); - if (allFloatingPoint) - return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1); + if (!allInteger || allBool) { + if (emitError) { + emitError() << "unsupported operation: unsigned min not on uint"; + return nullptr; + } + llvm_unreachable("unsupported operation: unsigned min not on uint"); + } return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::powf: assert(allFloatingPoint); diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 254458a978828..fb2570c7bb498 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -532,9 +532,9 @@ def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}") def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MaximumFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if ( + _is_integer_type(lhs.type) and not _is_bool_type(lhs.type) + ) or _is_index_type(lhs.type): return arith.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}") @@ -546,9 +546,9 @@ def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value: raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}") def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MinimumFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if ( + _is_integer_type(lhs.type) and not _is_bool_type(lhs.type) + ) or _is_index_type(lhs.type): return arith.MinUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}") @@ -634,6 +634,12 @@ def _is_index_type(t: Type) -> bool: return IndexType.isinstance(t) +def _is_bool_type(t: Type) -> bool: + if not IntegerType.isinstance(t): + return False + return IntegerType(t).width == 1 + + def _get_floating_point_width(t: Type) -> int: # TODO: Create a FloatType in the Python API and implement the switch # there. diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir index 8f22cc749bee9..ffea6ad4c5b50 100644 --- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -104,16 +104,3 @@ func.func @pooling_nhwc_min_unsigned_integer(%input: tensor, %filte // CHECK: @pooling_nhwc_min_unsigned_integer // CHECK: linalg.pooling_nhwc_min_unsigned // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> - -// ----- - -func.func @pooling_nhwc_min_unsigned_float(%input: tensor, %filter: tensor, %output: tensor) -> tensor { - %0 = linalg.pooling_nhwc_min_unsigned - {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - ins (%input, %filter: tensor, tensor) - outs (%output: tensor) -> tensor - return %0 : tensor -} -// CHECK: @pooling_nhwc_min_unsigned_float -// CHECK: linalg.pooling_nhwc_min -// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir index 552a0abaa797c..4ecf685b4c695 100644 --- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir +++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir @@ -80,6 +80,20 @@ func.func @divu_broadcast(%arg0: memref<8x16xi32>, %arg1: memref<4x8x16xi32>, %a // ----- +func.func @pooling_nhwc_max_unsigned_float( + %input: tensor, + %filter: tensor, + %init_val: tensor) -> tensor { + // CHECK: unsupported operation: unsigned max not on uint + linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init_val: tensor) -> tensor + return %init_val : tensor +} + +// ----- + func.func @exp_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) { // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32') linalg.exp ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>) @@ -349,4 +363,3 @@ func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref< linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>) return } - diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index a93e9799ceb3f..c2a8f24624d8e 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -705,6 +705,23 @@ func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x return %res : tensor<1x2x2x1xf32> } +// ----- + +// CHECK-LABEL: func @pooling_nhwc_max_unsigned_tensor +// CHECK: %{{.+}} = linalg.pooling_nhwc_max_unsigned +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> +func.func @pooling_nhwc_max_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> { + %fake = tensor.empty() : tensor<3x3xi32> + %init = tensor.empty() : tensor<1x2x2x1xi32> + %cst = arith.constant 0 : i32 + %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> + %res = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>) + outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> + return %res : tensor<1x2x2x1xi32> +} + // ----- // CHECK-LABEL: func @pooling_nwc_max_tensor // CHECK: %{{.+}} = linalg.pooling_nwc_max @@ -1017,6 +1034,23 @@ func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x // ----- +// CHECK-LABEL: func @pooling_nhwc_min_unsigned_tensor +// CHECK: %{{.+}} = linalg.pooling_nhwc_min_unsigned +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> +func.func @pooling_nhwc_min_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> { + %fake = tensor.empty() : tensor<3x3xi32> + %init = tensor.empty() : tensor<1x2x2x1xi32> + %cst = arith.constant 0 : i32 + %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> + %res = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>) + outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> + return %res : tensor<1x2x2x1xi32> +} + +// ----- + // CHECK-LABEL: func @pooling_nwc_min_tensor // CHECK: %{{.+}} = linalg.pooling_nwc_min // CHECK-SAME: dilations = dense<1> : tensor<1xi64> diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir index 72acf43361f50..60a4c555fa19a 100644 --- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir @@ -131,10 +131,10 @@ func.func @pooling_nhwc_max(%input: tensor, %filter: tensor<1x?xf32 } // CHECK-LABEL: @pooling_nhwc_max_unsigned -// CHECK-SAME: %[[ARG0:.+]]: tensor, -// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32> -// CHECK-SAME: %[[ARG2:.+]]: tensor -func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tensor<1x?xf32>, %init: tensor) -> tensor { +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32> +// CHECK-SAME: %[[ARG2:.+]]: tensor +func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tensor<1x?xi32>, %init: tensor) -> tensor { // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]] // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]] // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]] @@ -142,10 +142,10 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tenso // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]] %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - ins (%input, %filter: tensor, tensor<1x?xf32>) - outs (%init: tensor) -> tensor + ins (%input, %filter: tensor, tensor<1x?xi32>) + outs (%init: tensor) -> tensor // CHECK: return %[[RES]] - return %0 : tensor + return %0 : tensor } // CHECK-LABEL: @pooling_nhwc_min @@ -167,10 +167,10 @@ func.func @pooling_nhwc_min(%input: tensor, %filter: tensor<1x?xf32 } // CHECK-LABEL: @pooling_nhwc_min_unsigned -// CHECK-SAME: %[[ARG0:.+]]: tensor, -// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32> -// CHECK-SAME: %[[ARG2:.+]]: tensor -func.func @pooling_nhwc_min_unsigned(%input: tensor, %filter: tensor<1x?xf32>, %init: tensor) -> tensor { +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32> +// CHECK-SAME: %[[ARG2:.+]]: tensor +func.func @pooling_nhwc_min_unsigned(%input: tensor, %filter: tensor<1x?xi32>, %init: tensor) -> tensor { // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]] // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]] // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]] @@ -178,10 +178,10 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor, %filter: tenso // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]] %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - ins (%input, %filter: tensor, tensor<1x?xf32>) - outs (%init: tensor) -> tensor + ins (%input, %filter: tensor, tensor<1x?xi32>) + outs (%init: tensor) -> tensor // CHECK: return %[[RES]] - return %0 : tensor + return %0 : tensor } // CHECK-LABEL: @pooling_nchw_max diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py index 4ce0fbc1dbe53..0df87de6393d8 100644 --- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py @@ -150,3 +150,51 @@ def test_f32f32_min_pooling(input, shape, init_result): print(module) + +with Context() as ctx, Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f32 = F32Type.get() + bool_t = IntegerType.get_signless(1) + + # CHECK: bool_max_unsigned_error: Unsupported 'max_unsigned' operands + @func.FuncOp.from_py_func( + RankedTensorType.get((1, 4, 16, 1), f32), + RankedTensorType.get((2, 2), f32), + RankedTensorType.get((1, 2, 4, 1), bool_t), + ) + def test_bool_i1_max_unsigned_pooling_error(input, shape, init_result): + try: + pooling_poly( + input, + shape, + outs=[init_result], + reduce=BinaryFn.max_unsigned, + cast=TypeFn.cast_unsigned, + strides=[2, 4], + dilations=[1, 2], + ) + except NotImplementedError as e: + print(f"bool_max_unsigned_error: {e}") + return init_result + + # CHECK: float_max_unsigned_error: Unsupported 'max_unsigned' operands + @func.FuncOp.from_py_func( + RankedTensorType.get((1, 4, 16, 1), f32), + RankedTensorType.get((2, 2), f32), + RankedTensorType.get((1, 2, 4, 1), f32), + ) + def test_f32f32_max_unsigned_pooling_error(input, shape, init_result): + try: + pooling_poly( + input, + shape, + outs=[init_result], + reduce=BinaryFn.max_unsigned, + cast=TypeFn.cast_unsigned, + strides=[2, 4], + dilations=[1, 2], + ) + except NotImplementedError as e: + print(f"float_max_unsigned_error: {e}") + return init_result