Skip to content

Commit 6937dbb

Browse files
[mlir][memref] Fix alloca lowering with 0 dimensions (#111119)
The `memref.alloca` lowering computed the allocation size incorrectly when there were 0 dimensions. Previously: ``` memref.alloca() : memref<10x0x2xf32> --> llvm.alloca 20xf32 ``` Now: ``` memref.alloca() : memref<10x0x2xf32> --> llvm.alloca 0xf32 ``` From the `llvm.alloca` documentation: ``` Allocating zero bytes is legal, but the returned pointer may not be unique. ```
1 parent 19992ee commit 6937dbb

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,6 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
139139
strides[i] = runningStride;
140140

141141
int64_t staticSize = memRefType.getShape()[i];
142-
if (staticSize == 0)
143-
continue;
144142
bool useSizeAsStride = stride == 1;
145143
if (staticSize == ShapedType::kDynamic)
146144
stride = ShapedType::kDynamic;

mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,22 @@ func.func @static_alloca() -> memref<32x18xf32> {
9595

9696
// -----
9797

98+
// CHECK-LABEL: func @static_alloca_zero()
99+
func.func @static_alloca_zero() -> memref<32x0x18xf32> {
100+
// CHECK: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : i64
101+
// CHECK: %[[sz2:.*]] = llvm.mlir.constant(0 : index) : i64
102+
// CHECK: %[[sz3:.*]] = llvm.mlir.constant(18 : index) : i64
103+
// CHECK: %[[st1:.*]] = llvm.mlir.constant(1 : index) : i64
104+
// CHECK: %[[st2:.*]] = llvm.mlir.constant(0 : index) : i64
105+
// CHECK: %[[num_elems:.*]] = llvm.mlir.constant(0 : index) : i64
106+
// CHECK: %[[allocated:.*]] = llvm.alloca %[[num_elems]] x f32 : (i64) -> !llvm.ptr
107+
%0 = memref.alloca() : memref<32x0x18xf32>
108+
109+
return %0 : memref<32x0x18xf32>
110+
}
111+
112+
// -----
113+
98114
// CHECK-LABEL: func @static_dealloc
99115
func.func @static_dealloc(%static: memref<10x8xf32>) {
100116
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>

0 commit comments

Comments
 (0)