From 07b3df3518182ac1f9da5840eaa2744801137581 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 12 Feb 2025 12:19:13 +0100 Subject: [PATCH] [mlir] ArithToLLVM: fix memref bitcast lowering (#125148) `arith.bitcast` is allowed on memrefs and such code can actually be generated by IREE `ConvertBf16ArithToF32Pass`. `LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types and will generate bitcast between structs which is illegal. With the opaque pointers this is a no-op operation for memref so we can just add a separate pattern which removes op if converted types are the same. --- .../Conversion/ArithToLLVM/ArithToLLVM.cpp | 25 +++++++++++++++++ .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 28 ++++++++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 754ed89814293..ced18a48766bf 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -54,6 +54,25 @@ struct ConstrainedVectorConvertToLLVMPattern } }; +/// No-op bitcast. Propagate type input arg if converted source and dest types +/// are the same. +struct IdentityBitcastLowering final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + Value src = adaptor.getIn(); + Type resultType = getTypeConverter()->convertType(op.getType()); + if (src.getType() != resultType) + return rewriter.notifyMatchFailure(op, "Types are different"); + + rewriter.replaceOp(op, src); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Straightforward Op Lowerings //===----------------------------------------------------------------------===// @@ -524,6 +543,12 @@ void mlir::arith::registerConvertArithToLLVMInterface( void mlir::arith::populateArithToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { + + // Set a higher pattern benefit for IdentityBitcastLowering so it will run + // before BitcastOpLowering. + patterns.add(converter, patterns.getContext(), + /*patternBenefit*/ 10); + // clang-format off patterns.add< AddFOpLowering, diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 1dabacfd8a47c..7daf4ef8717bc 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -577,12 +577,26 @@ func.func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) { // ----- // CHECK-LABEL: @select +// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) func.func @select(%arg0 : i1, %arg1 : i32, %arg2 : i32) -> i32 { - // CHECK: = llvm.select %arg0, %arg1, %arg2 : i1, i32 + // CHECK: %[[RES:.*]] = llvm.select %[[ARG0]], %[[ARG1]], %[[ARG2]] : i1, i32 + // CHECK: return %[[RES]] %0 = arith.select %arg0, %arg1, %arg2 : i32 return %0 : i32 } +// CHECK-LABEL: @select_complex +// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: complex, %[[ARG2:.*]]: complex) +func.func @select_complex(%arg0 : i1, %arg1 : complex, %arg2 : complex) -> complex { + // CHECK-DAG: %[[ARGC1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : complex to !llvm.struct<(f32, f32)> + // CHECK-DAG: %[[ARGC2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : complex to !llvm.struct<(f32, f32)> + // CHECK: %[[RES:.*]] = llvm.select %[[ARG0]], %[[ARGC1]], %[[ARGC2]] : i1, !llvm.struct<(f32, f32)> + // CHECK: %[[RESC:.*]] = builtin.unrealized_conversion_cast %[[RES]] : !llvm.struct<(f32, f32)> to complex + // CHECK: return %[[RESC]] + %0 = arith.select %arg0, %arg1, %arg2 : complex + return %0 : complex +} + // ----- // CHECK-LABEL: @ceildivsi @@ -727,3 +741,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) { %3 = arith.shli %arg0, %arg1 overflow : i64 return } + +// ----- + +// CHECK-LABEL: func @memref_bitcast +// CHECK-SAME: (%[[ARG:.*]]: memref) +// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref +// CHECK: return %[[V2]] +func.func @memref_bitcast(%1: memref) -> memref { + %2 = arith.bitcast %1 : memref to memref + func.return %2 : memref +}