diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 7f45904fab7e1..d78de25f3e0ff 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1711,6 +1711,10 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { MemRefDescriptor sourceMemRef(adaptor.getSource()); auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy); + // Early exit for 0-D corner case. + if (viewMemRefType.getRank() == 0) + return rewriter.replaceOp(viewOp, {targetMemRef}), success(); + // Field 1: Copy the allocated pointer, used for malloc/free. Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); auto srcMemRefType = cast(viewOp.getSource().getType()); @@ -1733,10 +1737,6 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, offset)); - // Early exit for 0-D corner case. - if (viewMemRefType.getRank() == 0) - return rewriter.replaceOp(viewOp, {targetMemRef}), success(); - // Fields 4 and 5: Update sizes and strides. Value stride = nullptr, nextSize = nullptr; for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {