Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def MemRef_Dialect : Dialect {
manipulation ops, which are not strongly associated with any particular
other dialect or domain abstraction.
}];
let dependentDialects = ["arith::ArithDialect"];
let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
let hasConstantMaterializer = 1;
}

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRMemorySlotInterfaces
MLIRShapedOpInterfaces
MLIRSideEffectInterfaces
MLIRUBDialect
MLIRValueBoundsOpInterface
MLIRViewLikeInterface
)
1 change: 1 addition & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
Expand Down
19 changes: 2 additions & 17 deletions mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
// Interfaces for AllocaOp
//===----------------------------------------------------------------------===//

static bool isSupportedElementType(Type type) {
return llvm::isa<MemRefType>(type) ||
OpBuilder(type.getContext()).getZeroAttr(type);
}

SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
MemRefType type = getType();
if (!isSupportedElementType(type.getElementType()))
return {};
if (!type.hasStaticShape())
return {};
// Make sure the memref contains only a single element.
Expand All @@ -81,16 +75,7 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {

Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
OpBuilder &builder) {
assert(isSupportedElementType(slot.elemType));
// TODO: support more types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](MemRefType t) {
return memref::AllocaOp::create(builder, getLoc(), t);
})
.Default([&](Type t) {
return arith::ConstantOp::create(builder, getLoc(), t,
builder.getZeroAttr(t));
});
return ub::PoisonOp::create(builder, getLoc(), slot.elemType);
}

std::optional<PromotableAllocationOpInterface>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/MemRef/mem2reg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func.func @basic() -> i32 {
// CHECK-LABEL: func.func @basic_default
func.func @basic_default() -> i32 {
// CHECK-NOT: = memref.alloca
// CHECK: %[[RES:.*]] = arith.constant 0 : i32
// CHECK: %[[RES:.*]] = ub.poison : i32
// CHECK-NOT: = memref.alloca
%0 = arith.constant 5 : i32
%1 = memref.alloca() : memref<i32>
Expand Down
Loading