Skip to content

Commit f5b7376

Browse files
authored
[mlir][MemRef] Add UB as a dependent dialect and use ub.poison for Mem2Reg (#168066)
This patch adds `ub` as a dependent dialect to `memref`, and uses `ub.poison` as the default value in `AllocaOp::getDefaultValue` for the mem2reg pass. This aligns the behavior of `mem2reg` with LLVM, where loading a value before having a value should be poison. --------- Signed-off-by: Fabian Mora <[email protected]>
1 parent f210fc1 commit f5b7376

File tree

5 files changed

+13
-19
lines changed

5 files changed

+13
-19
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@ def MemRef_Dialect : Dialect {
1919
manipulation ops, which are not strongly associated with any particular
2020
other dialect or domain abstraction.
2121
}];
22-
let dependentDialects = ["arith::ArithDialect"];
22+
let dependentDialects = [
23+
// `arith` is a dependency because it is used to materialize constants,
24+
// and in some canonicalization patterns.
25+
"arith::ArithDialect",
26+
// `ub` is a dependency because `AllocaOp::getDefaultValue` can produce a
27+
// `ub.poison` value.
28+
"ub::UBDialect"
29+
];
2330
let hasConstantMaterializer = 1;
2431
}
2532

mlir/lib/Dialect/MemRef/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
2525
MLIRMemorySlotInterfaces
2626
MLIRShapedOpInterfaces
2727
MLIRSideEffectInterfaces
28+
MLIRUBDialect
2829
MLIRValueBoundsOpInterface
2930
MLIRViewLikeInterface
3031
)

mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1111
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
1212
#include "mlir/Dialect/MemRef/IR/MemRef.h"
13+
#include "mlir/Dialect/UB/IR/UBOps.h"
1314
#include "mlir/IR/BuiltinTypes.h"
1415
#include "mlir/Interfaces/MemorySlotInterfaces.h"
1516
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"

mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
1515
#include "mlir/Dialect/MemRef/IR/MemRef.h"
16+
#include "mlir/Dialect/UB/IR/UBOps.h"
1617
#include "mlir/IR/BuiltinDialect.h"
1718
#include "mlir/IR/BuiltinTypes.h"
1819
#include "mlir/IR/Matchers.h"
@@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
6162
// Interfaces for AllocaOp
6263
//===----------------------------------------------------------------------===//
6364

64-
static bool isSupportedElementType(Type type) {
65-
return llvm::isa<MemRefType>(type) ||
66-
OpBuilder(type.getContext()).getZeroAttr(type);
67-
}
68-
6965
SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
7066
MemRefType type = getType();
71-
if (!isSupportedElementType(type.getElementType()))
72-
return {};
7367
if (!type.hasStaticShape())
7468
return {};
7569
// Make sure the memref contains only a single element.
@@ -81,16 +75,7 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
8175

8276
Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
8377
OpBuilder &builder) {
84-
assert(isSupportedElementType(slot.elemType));
85-
// TODO: support more types.
86-
return TypeSwitch<Type, Value>(slot.elemType)
87-
.Case([&](MemRefType t) {
88-
return memref::AllocaOp::create(builder, getLoc(), t);
89-
})
90-
.Default([&](Type t) {
91-
return arith::ConstantOp::create(builder, getLoc(), t,
92-
builder.getZeroAttr(t));
93-
});
78+
return ub::PoisonOp::create(builder, getLoc(), slot.elemType);
9479
}
9580

9681
std::optional<PromotableAllocationOpInterface>

mlir/test/Dialect/MemRef/mem2reg.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func.func @basic() -> i32 {
1818
// CHECK-LABEL: func.func @basic_default
1919
func.func @basic_default() -> i32 {
2020
// CHECK-NOT: = memref.alloca
21-
// CHECK: %[[RES:.*]] = arith.constant 0 : i32
21+
// CHECK: %[[RES:.*]] = ub.poison : i32
2222
// CHECK-NOT: = memref.alloca
2323
%0 = arith.constant 5 : i32
2424
%1 = memref.alloca() : memref<i32>

0 commit comments

Comments
 (0)