From 8a7b9275b625274d92c4376f3e7ad9ffeb328d30 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Mon, 16 Dec 2024 11:01:45 +0000 Subject: [PATCH 1/3] EmitC: Allow casts between opaque and float types --- mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 938bc73d43996..b68ecc908c089 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -757,7 +757,9 @@ class TruncFConversion : public OpConversionPattern { return rewriter.notifyMatchFailure(castOp, "unsupported cast destination type"); - if (!castOp.areCastCompatible(operandType, dstType)) + if (!isa(dstType) && + !isa(operandType) && + !castOp.areCastCompatible(operandType, dstType)) return rewriter.notifyMatchFailure(castOp, "cast-incompatible types"); rewriter.replaceOpWithNewOp(castOp, dstType, @@ -787,7 +789,9 @@ class ExtFConversion : public OpConversionPattern { return rewriter.notifyMatchFailure(castOp, "unsupported cast destination type"); - if (!castOp.areCastCompatible(operandType, dstType)) + if (!isa(dstType) && + !isa(operandType) && + !castOp.areCastCompatible(operandType, dstType)) return rewriter.notifyMatchFailure(castOp, "cast-incompatible types"); rewriter.replaceOpWithNewOp(castOp, dstType, From 96d2f928e5b1f3f553f1c354ef7aa12922c95d28 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Mon, 16 Dec 2024 12:17:40 +0000 Subject: [PATCH 2/3] Use EmitC cast compatibility check --- mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index b68ecc908c089..6a84ead33d5c2 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -757,9 +757,7 @@ class TruncFConversion : public OpConversionPattern { return rewriter.notifyMatchFailure(castOp, "unsupported cast destination type"); - if (!isa(dstType) && - !isa(operandType) && - !castOp.areCastCompatible(operandType, dstType)) + if (!emitc::CastOp::areCastCompatible(operandType, dstType)) return rewriter.notifyMatchFailure(castOp, "cast-incompatible types"); rewriter.replaceOpWithNewOp(castOp, dstType, @@ -789,9 +787,7 @@ class ExtFConversion : public OpConversionPattern { return rewriter.notifyMatchFailure(castOp, "unsupported cast destination type"); - if (!isa(dstType) && - !isa(operandType) && - !castOp.areCastCompatible(operandType, dstType)) + if (!emitc::CastOp::areCastCompatible(operandType, dstType)) return rewriter.notifyMatchFailure(castOp, "cast-incompatible types"); rewriter.replaceOpWithNewOp(castOp, dstType, From 3a6675a669f971fd8857e15eb067e9cfdff058f9 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Mon, 16 Dec 2024 13:18:15 +0000 Subject: [PATCH 3/3] Allow opaque types in casts (also for array types) --- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 4 ++++ mlir/test/Dialect/EmitC/ops.mlir | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 8ed1d609b9181..66421c2f6fff6 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -313,6 +313,10 @@ LogicalResult emitc::AssignOp::verify() { bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { Type input = inputs.front(), output = outputs.front(); + // Opaque types are always allowed + if (isa(input) || isa(output)) + return true; + // Cast to array is only possible from an array if (isa(input) != isa(output)) return false; diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index ad70ea61cb295..80a33b2b9621f 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -36,11 +36,15 @@ emitc.func private @extern(i32) attributes {specifiers = ["extern"]} func.func @cast(%arg0: i32) { %1 = emitc.cast %arg0: i32 to f32 + %2 = emitc.cast %1: f32 to !emitc.opaque<"some type"> + %3 = emitc.cast %2: !emitc.opaque<"some type"> to !emitc.size_t return } func.func @cast_array(%arg : !emitc.array<4xf32>) { %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> ref + %2 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.opaque<"some type"> + %3 = emitc.cast %2: !emitc.opaque<"some type"> to !emitc.array<4xf32> ref return }