From 3852cc28f517273121e83ac3b21094f065fc5749 Mon Sep 17 00:00:00 2001 From: Andrey Timonin Date: Sat, 8 Feb 2025 21:40:30 +0300 Subject: [PATCH 1/2] [mlir][emitc] Add an option to cast array type to ptr type --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 3 +-- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 11 ++++++----- mlir/test/Dialect/EmitC/invalid_ops.mlir | 14 +++++++++++--- mlir/test/Dialect/EmitC/ops.mlir | 5 +++++ 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 4fbce995ce5b8..360f2e8434363 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -266,8 +266,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { def EmitC_CastOp : EmitC_Op<"cast", [CExpression, - DeclareOpInterfaceMethods, - SameOperandsAndResultShape]> { + DeclareOpInterfaceMethods]> { let summary = "Cast operation"; let description = [{ The `emitc.cast` operation performs an explicit type conversion and is emitted diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 728a2d33f46e7..01effa5734caa 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -247,11 +247,12 @@ LogicalResult emitc::AssignOp::verify() { bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { Type input = inputs.front(), output = outputs.front(); - return ( - (emitc::isIntegerIndexOrOpaqueType(input) || - emitc::isSupportedFloatType(input) || isa(input)) && - (emitc::isIntegerIndexOrOpaqueType(output) || - emitc::isSupportedFloatType(output) || isa(output))); + return ((emitc::isIntegerIndexOrOpaqueType(input) || + emitc::isSupportedFloatType(input) || + isa(input) || isa(input)) && + (emitc::isIntegerIndexOrOpaqueType(output) || + emitc::isSupportedFloatType(output) || + isa(output))); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index a0d8d7f59de11..c40195dd3473a 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -130,9 +130,17 @@ func.func @cast_tensor(%arg : tensor) { // ----- -func.func @cast_array(%arg : !emitc.array<4xf32>) { - // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}} - %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> +func.func @cast_to_array(%arg : f32) { + // expected-error @+1 {{'emitc.cast' op operand type 'f32' and result type '!emitc.array<4xf32>' are cast incompatible}} + %1 = emitc.cast %arg: f32 to !emitc.array<4xf32> + return +} + +// ----- + +func.func @cast_pointer_to_array(%arg : !emitc.ptr) { + // expected-error @+1 {{'emitc.cast' op operand type '!emitc.ptr' and result type '!emitc.array<3xi32>' are cast incompatible}} + %1 = emitc.cast %arg: !emitc.ptr to !emitc.array<3xi32> return } diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 7fd0a2d020397..c6f90f5600855 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -39,6 +39,11 @@ func.func @cast(%arg0: i32) { return } +func.func @cast_array_to_pointer(%arg0: !emitc.array<3xi32>) { + %1 = emitc.cast %arg0: !emitc.array<3xi32> to !emitc.ptr + return +} + func.func @c() { %1 = "emitc.constant"(){value = 42 : i32} : () -> i32 %2 = "emitc.constant"(){value = 42 : index} : () -> !emitc.size_t From 2199c2fb559b98cf3626feea39f6ed3f40e6fa21 Mon Sep 17 00:00:00 2001 From: EtoAndruwa Date: Sun, 9 Feb 2025 00:41:40 +0300 Subject: [PATCH 2/2] [mlir][emitc] Strengthen type and rank checks --- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 29 +++++++++++++++--------- mlir/test/Dialect/EmitC/invalid_ops.mlir | 24 ++++++++++++++++++++ 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 01effa5734caa..80581a9f814ac 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -247,12 +247,19 @@ LogicalResult emitc::AssignOp::verify() { bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { Type input = inputs.front(), output = outputs.front(); - return ((emitc::isIntegerIndexOrOpaqueType(input) || - emitc::isSupportedFloatType(input) || - isa(input) || isa(input)) && - (emitc::isIntegerIndexOrOpaqueType(output) || - emitc::isSupportedFloatType(output) || - isa(output))); + if (auto arrayType = dyn_cast(input)) { + if (auto pointerType = dyn_cast(output)) { + return (arrayType.getElementType() == pointerType.getPointee()) && + arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1; + } + return false; + } + + return ( + (emitc::isIntegerIndexOrOpaqueType(input) || + emitc::isSupportedFloatType(input) || isa(input)) && + (emitc::isIntegerIndexOrOpaqueType(output) || + emitc::isSupportedFloatType(output) || isa(output))); } //===----------------------------------------------------------------------===// @@ -700,9 +707,9 @@ void IfOp::print(OpAsmPrinter &p) { /// Given the region at `index`, or the parent operation if `index` is None, /// return the successor regions. These are the regions that may be selected -/// during the flow of control. `operands` is a set of optional attributes that -/// correspond to a constant value for each operand, or null if that operand is -/// not a constant. +/// during the flow of control. `operands` is a set of optional attributes +/// that correspond to a constant value for each operand, or null if that +/// operand is not a constant. void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. @@ -1000,8 +1007,8 @@ emitc::ArrayType::cloneWith(std::optional> shape, LogicalResult mlir::emitc::LValueType::verify( llvm::function_ref emitError, mlir::Type value) { - // Check that the wrapped type is valid. This especially forbids nested lvalue - // types. + // Check that the wrapped type is valid. This especially forbids nested + // lvalue types. if (!isSupportedEmitCType(value)) return emitError() << "!emitc.lvalue must wrap supported emitc type, but got " << value; diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index c40195dd3473a..b58981689919b 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -138,6 +138,30 @@ func.func @cast_to_array(%arg : f32) { // ----- +func.func @cast_multidimensional_array(%arg : !emitc.array<1x2xi32>) { + // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<1x2xi32>' and result type '!emitc.ptr' are cast incompatible}} + %1 = emitc.cast %arg: !emitc.array<1x2xi32> to !emitc.ptr + return +} + +// ----- + +func.func @cast_array_zero_rank(%arg : !emitc.array<0xi32>) { + // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<0xi32>' and result type '!emitc.ptr' are cast incompatible}} + %1 = emitc.cast %arg: !emitc.array<0xi32> to !emitc.ptr + return +} + +// ----- + +func.func @cast_array_to_pointer_types_mismatch(%arg : !emitc.array<3xi32>) { + // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<3xi32>' and result type '!emitc.ptr' are cast incompatible}} + %1 = emitc.cast %arg: !emitc.array<3xi32> to !emitc.ptr + return +} + +// ----- + func.func @cast_pointer_to_array(%arg : !emitc.ptr) { // expected-error @+1 {{'emitc.cast' op operand type '!emitc.ptr' and result type '!emitc.array<3xi32>' are cast incompatible}} %1 = emitc.cast %arg: !emitc.ptr to !emitc.array<3xi32>