Skip to content

Commit c6cedb2

Browse files
committed
[mlir][memref] Revert llvm#140730
Reverts llvm#140730 - that turned out not to be an NFC as we originally thought. See the attached test for an example. Many thanks to @Garra1980 for reporting!
1 parent bc0c4db commit c6cedb2

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,10 +1721,6 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
17211721
MemRefDescriptor sourceMemRef(adaptor.getSource());
17221722
auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);
17231723

1724-
// Early exit for 0-D corner case.
1725-
if (viewMemRefType.getRank() == 0)
1726-
return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1727-
17281724
// Field 1: Copy the allocated pointer, used for malloc/free.
17291725
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
17301726
auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
@@ -1747,6 +1743,10 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
17471743
rewriter, loc,
17481744
createIndexAttrConstant(rewriter, loc, indexType, offset));
17491745

1746+
// Early exit for 0-D corner case.
1747+
if (viewMemRefType.getRank() == 0)
1748+
return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1749+
17501750
// Fields 4 and 5: Update sizes and strides.
17511751
Value stride = nullptr, nextSize = nullptr;
17521752
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
// RUN: mlir-opt -finalize-memref-to-llvm %s -split-input-file | FileCheck %s
1+
// RUN: mlir-opt -finalize-memref-to-llvm %s -split-input-file | FileCheck --check-prefixes=ALL,CHECK %s
22
// RUN: mlir-opt -finalize-memref-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s
33

44
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
55
// and the generic `convert-to-llvm` pass. This produces slightly different IR
66
// because the conversion target is set up differently.
7-
// RUN: mlir-opt --convert-to-llvm="filter-dialects=memref" --split-input-file %s | FileCheck --check-prefix=CHECK-INTERFACE %s
7+
// RUN: mlir-opt --convert-to-llvm="filter-dialects=memref" --split-input-file %s | FileCheck --check-prefixes=ALL,CHECK-INTERFACE %s
88

99
// CHECK-LABEL: func @view(
1010
// CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index
@@ -132,6 +132,28 @@ func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) {
132132

133133
// -----
134134

135+
// ALL-LABEL: func.func @view_memref_as_rank0(
136+
// ALL-SAME: %[[ARG0:.*]]: index,
137+
// ALL-SAME: %[[ARG1:.*]]: memref<2xi8>) {
138+
func.func @view_memref_as_rank0(%offset: index, %mem: memref<2xi8>) {
139+
140+
// ALL: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : index to i64
141+
// ALL: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
142+
// ALL: %[[VAL_2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
143+
// ALL: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
144+
// ALL: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][0] : !llvm.struct<(ptr, ptr, i64)>
145+
// ALL: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
146+
// ALL: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_5]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
147+
// ALL: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_4]][1] : !llvm.struct<(ptr, ptr, i64)>
148+
// ALL: %[[VAL_8:.*]] = llvm.mlir.constant(0 : index) : i64
149+
// ALL: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_7]][2] : !llvm.struct<(ptr, ptr, i64)>
150+
%memref_view_bf16 = memref.view %mem[%offset][] : memref<2xi8> to memref<bf16>
151+
152+
return
153+
}
154+
155+
// -----
156+
135157
// Subviews needs to be expanded outside of the memref-to-llvm pass.
136158
// CHECK-LABEL: func @subview(
137159
// CHECK: %[[MEMREF:.*]]: memref<{{.*}}>,

0 commit comments

Comments
 (0)