diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 938bc73d43996..6a84ead33d5c2 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -757,7 +757,7 @@ class TruncFConversion : public OpConversionPattern { return rewriter.notifyMatchFailure(castOp, "unsupported cast destination type"); - if (!castOp.areCastCompatible(operandType, dstType)) + if (!emitc::CastOp::areCastCompatible(operandType, dstType)) return rewriter.notifyMatchFailure(castOp, "cast-incompatible types"); rewriter.replaceOpWithNewOp(castOp, dstType, @@ -787,7 +787,7 @@ class ExtFConversion : public OpConversionPattern { return rewriter.notifyMatchFailure(castOp, "unsupported cast destination type"); - if (!castOp.areCastCompatible(operandType, dstType)) + if (!emitc::CastOp::areCastCompatible(operandType, dstType)) return rewriter.notifyMatchFailure(castOp, "cast-incompatible types"); rewriter.replaceOpWithNewOp(castOp, dstType, diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 8ed1d609b9181..66421c2f6fff6 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -313,6 +313,10 @@ LogicalResult emitc::AssignOp::verify() { bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { Type input = inputs.front(), output = outputs.front(); + // Opaque types are always allowed + if (isa(input) || isa(output)) + return true; + // Cast to array is only possible from an array if (isa(input) != isa(output)) return false; diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index ad70ea61cb295..80a33b2b9621f 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -36,11 +36,15 @@ emitc.func private @extern(i32) attributes {specifiers = ["extern"]} func.func @cast(%arg0: i32) { %1 = emitc.cast %arg0: i32 to f32 + %2 = emitc.cast %1: f32 to !emitc.opaque<"some type"> + %3 = emitc.cast %2: !emitc.opaque<"some type"> to !emitc.size_t return } func.func @cast_array(%arg : !emitc.array<4xf32>) { %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> ref + %2 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.opaque<"some type"> + %3 = emitc.cast %2: !emitc.opaque<"some type"> to !emitc.array<4xf32> ref return }