From ce9af7449b59c668caca6fb1d0b7b65d0f7147af Mon Sep 17 00:00:00 2001 From: Johannes de Fine Licht Date: Mon, 17 Mar 2025 14:35:23 +0000 Subject: [PATCH] [MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics. This was lacking a bitcast from the shifted integer type into a float. Non-struct types other than integers and floats are still not Mem2Reg'ed. Also adds special handling so that a constant memset value is emitted directly as a constant rather than relying on follow-up canonicalization patterns (`memset` of zero is a case that can appear in the wild). --- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 73 ++++++++++++------- .../Dialect/LLVMIR/mem2reg-intrinsics.mlir | 61 +++++++++++----- 2 files changed, 90 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 655316cc5d66d..d1ccb487d2265 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -1051,30 +1051,52 @@ static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot, template static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot, OpBuilder &builder) { - // TODO: Support non-integer types. - return TypeSwitch(slot.elemType) - .Case([&](IntegerType intType) -> Value { - if (intType.getWidth() == 8) - return op.getVal(); - - assert(intType.getWidth() % 8 == 0); - - // Build the memset integer by repeatedly shifting the value and - // or-ing it with the previous value. 
- uint64_t coveredBits = 8; - Value currentValue = - builder.create(op.getLoc(), intType, op.getVal()); - while (coveredBits < intType.getWidth()) { - Value shiftBy = builder.create(op.getLoc(), intType, - coveredBits); - Value shifted = - builder.create(op.getLoc(), currentValue, shiftBy); - currentValue = - builder.create(op.getLoc(), currentValue, shifted); - coveredBits *= 2; - } + /// Returns an integer value that is `width` bits wide representing the value + /// assigned to the slot by memset. + auto buildMemsetValue = [&](unsigned width) -> Value { + assert(width % 8 == 0); + auto intType = IntegerType::get(op.getContext(), width); + + // If we know the pattern at compile time, we can compute and assign a + // constant directly. + IntegerAttr constantPattern; + if (matchPattern(op.getVal(), m_Constant(&constantPattern))) { + assert(constantPattern.getValue().getBitWidth() == 8); + APInt memsetVal(/*numBits=*/width, /*val=*/0); + for (unsigned loBit = 0; loBit < width; loBit += 8) + memsetVal.insertBits(constantPattern.getValue(), loBit); + return builder.create( + op.getLoc(), IntegerAttr::get(intType, memsetVal)); + } + + // If the output is a single byte, we can return the pattern directly. + if (width == 8) + return op.getVal(); + + // Otherwise build the memset integer at runtime by repeatedly shifting the + // value and or-ing it with the previous value. 
+ uint64_t coveredBits = 8; + Value currentValue = + builder.create(op.getLoc(), intType, op.getVal()); + while (coveredBits < width) { + Value shiftBy = + builder.create(op.getLoc(), intType, coveredBits); + Value shifted = + builder.create(op.getLoc(), currentValue, shiftBy); + currentValue = + builder.create(op.getLoc(), currentValue, shifted); + coveredBits *= 2; + } - return currentValue; + return currentValue; + }; + return TypeSwitch(slot.elemType) + .Case([&](IntegerType type) -> Value { + return buildMemsetValue(type.getWidth()); + }) + .Case([&](FloatType type) -> Value { + Value intVal = buildMemsetValue(type.getWidth()); + return builder.create(op.getLoc(), type, intVal); }) .Default([](Type) -> Value { llvm_unreachable( @@ -1088,11 +1110,10 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses, const DataLayout &dataLayout) { - // TODO: Support non-integer types. bool canConvertType = TypeSwitch(slot.elemType) - .Case([](IntegerType intType) { - return intType.getWidth() % 8 == 0 && intType.getWidth() > 0; + .Case([](auto type) { + return type.getWidth() % 8 == 0 && type.getWidth() > 0; }) .Default([](Type) { return false; }); if (!canConvertType) diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir index 646667505a373..37c2f525a9dcb 100644 --- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir +++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir @@ -23,6 +23,30 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 { // ----- +// CHECK-LABEL: llvm.func @memset_float +// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8) +llvm.func @memset_float(%memset_value: i8) -> f32 { + %one = llvm.mlir.constant(1 : i32) : i32 + %alloca = llvm.alloca %one x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr + %memset_len = llvm.mlir.constant(4 : i32) : i32 + "llvm.intr.memset"(%alloca, %memset_value, %memset_len) <{isVolatile = false}> : 
(!llvm.ptr, i8, i32) -> () + // CHECK-NOT: "llvm.intr.memset" + // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32 + // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]] + // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]] + // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]] + // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]] + // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32 + // CHECK-NOT: "llvm.intr.memset" + %load = llvm.load %alloca {alignment = 4 : i64} : !llvm.ptr -> f32 + // CHECK: llvm.return %[[VALUE_FLOAT]] : f32 + llvm.return %load : f32 +} + +// ----- + // CHECK-LABEL: llvm.func @basic_memset_inline // CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8) llvm.func @basic_memset_inline(%memset_value: i8) -> i32 { @@ -53,20 +77,28 @@ llvm.func @basic_memset_constant() -> i32 { %memset_len = llvm.mlir.constant(4 : i32) : i32 "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32 - // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8 - // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32 - // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32 - // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32 - // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32 - // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32 - // CHECK: llvm.return %[[RES]] : i32 + // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32 + // CHECK: llvm.return %[[CONSTANT_VAL]] : i32 llvm.return %2 : i32 } // ----- +// CHECK-LABEL: llvm.func @memset_one_byte_constant +llvm.func 
@memset_one_byte_constant() -> i8 { + %one = llvm.mlir.constant(1 : i32) : i32 + %alloca = llvm.alloca %one x i8 : (i32) -> !llvm.ptr + // CHECK: %{{.+}} = llvm.mlir.constant(42 : i8) : i8 + %value = llvm.mlir.constant(42 : i8) : i8 + "llvm.intr.memset"(%alloca, %value, %one) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () + %load = llvm.load %alloca : !llvm.ptr -> i8 + // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(42 : i8) : i8 + // CHECK: llvm.return %[[CONSTANT_VAL]] : i8 + llvm.return %load : i8 +} + +// ----- + // CHECK-LABEL: llvm.func @basic_memset_inline_constant llvm.func @basic_memset_inline_constant() -> i32 { %0 = llvm.mlir.constant(1 : i32) : i32 @@ -74,15 +106,8 @@ llvm.func @basic_memset_inline_constant() -> i32 { %memset_value = llvm.mlir.constant(42 : i8) : i8 "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4}> : (!llvm.ptr, i8) -> () %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32 - // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8 - // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32 - // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32 - // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32 - // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32 - // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32 - // CHECK: llvm.return %[[RES]] : i32 + // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32 + // CHECK: llvm.return %[[CONSTANT_VAL]] : i32 llvm.return %2 : i32 }