diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 339f8a61136fc..ade4e4d3de8ec 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1721,10 +1721,6 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { MemRefDescriptor sourceMemRef(adaptor.getSource()); auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy); - // Early exit for 0-D corner case. - if (viewMemRefType.getRank() == 0) - return rewriter.replaceOp(viewOp, {targetMemRef}), success(); - // Field 1: Copy the allocated pointer, used for malloc/free. Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); auto srcMemRefType = cast(viewOp.getSource().getType()); @@ -1747,6 +1743,10 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, offset)); + // Early exit for 0-D corner case. + if (viewMemRefType.getRank() == 0) + return rewriter.replaceOp(viewOp, {targetMemRef}), success(); + // Fields 4 and 5: Update sizes and strides. Value stride = nullptr, nextSize = nullptr; for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 5538ddf8e4c3c..8c863bb2d3d65 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -1,10 +1,13 @@ -// RUN: mlir-opt -finalize-memref-to-llvm %s -split-input-file | FileCheck %s +// RUN: mlir-opt -finalize-memref-to-llvm %s -split-input-file | FileCheck --check-prefixes=ALL,CHECK %s // RUN: mlir-opt -finalize-memref-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. This produces slightly different IR // because the conversion target is set up differently. -// RUN: mlir-opt --convert-to-llvm="filter-dialects=memref" --split-input-file %s | FileCheck --check-prefix=CHECK-INTERFACE %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=memref" --split-input-file %s | FileCheck --check-prefixes=ALL,CHECK-INTERFACE %s + +// TODO: In some (all?) cases, CHECK and CHECK-INTERFACE outputs are identical. +// Use a common prefix instead (e.g. ALL). // CHECK-LABEL: func @view( // CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index @@ -132,6 +135,28 @@ func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) { // ----- +// ALL-LABEL: func.func @view_memref_as_rank0( +// ALL-SAME: %[[OFFSET:.*]]: index, +// ALL-SAME: %[[MEM:.*]]: memref<2xi8>) { +func.func @view_memref_as_rank0(%offset: index, %mem: memref<2xi8>) { + + // ALL: builtin.unrealized_conversion_cast %[[OFFSET]] : index to i64 + // ALL: builtin.unrealized_conversion_cast %[[MEM]] : memref<2xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ALL: llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)> + // ALL: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ALL: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64)> + // ALL: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // ALL: llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, i8 + // ALL: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> + // ALL: llvm.mlir.constant(0 : index) : i64 + // ALL: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64)> + %memref_view_bf16 = memref.view %mem[%offset][] : memref<2xi8> to memref + + return +} + +// ----- + // Subviews needs to be expanded outside of the memref-to-llvm pass. // CHECK-LABEL: func @subview( // CHECK: %[[MEMREF:.*]]: memref<{{.*}}>,