diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 2364f8957992d..8a708eb29210c 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2667,7 +2667,7 @@ AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) { if (!ivArg || !ivArg.getOwner()) return nullptr; Operation *containingOp = ivArg.getOwner()->getParentOp(); - auto parallelOp = dyn_cast(containingOp); + auto parallelOp = dyn_cast_if_present(containingOp); if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val)) return parallelOp; return nullptr; diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm-with-transforms.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm-with-transforms.mlir new file mode 100644 index 0000000000000..f6d0524fce39d --- /dev/null +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm-with-transforms.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt -test-memref-to-llvm-with-transforms %s | FileCheck %s + +// Checks that the program does not crash. The functionality of the pattern is +// already checked in test/Dialect/MemRef/*.mlir + +func.func @subview_folder(%arg0: memref<100x100xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> memref> { + %subview = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1] : memref<100x100xf32> to memref> + return %subview : memref> +} +// CHECK-LABEL: llvm.func @subview_folder diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt index c09496be729be..167cce225595b 100644 --- a/mlir/test/lib/Conversion/CMakeLists.txt +++ b/mlir/test/lib/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(ConvertToSPIRV) add_subdirectory(FuncToLLVM) add_subdirectory(MathToVCIX) +add_subdirectory(MemRefToLLVM) add_subdirectory(VectorToSPIRV) diff --git a/mlir/test/lib/Conversion/MemRefToLLVM/CMakeLists.txt b/mlir/test/lib/Conversion/MemRefToLLVM/CMakeLists.txt new file mode 100644 index 0000000000000..580c9ca4a6049 --- /dev/null +++ b/mlir/test/lib/Conversion/MemRefToLLVM/CMakeLists.txt @@ -0,0 +1,22 @@ +# Exclude tests from libMLIR.so +add_mlir_library(MLIRTestMemRefToLLVMWithTransforms + TestMemRefToLLVMWithTransforms.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRTestDialect + ) +mlir_target_link_libraries(MLIRTestMemRefToLLVMWithTransforms PUBLIC + MLIRFuncToLLVM + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRMemRefTransforms + MLIRPass + ) + +target_include_directories(MLIRTestMemRefToLLVMWithTransforms + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test + ) diff --git a/mlir/test/lib/Conversion/MemRefToLLVM/TestMemRefToLLVMWithTransforms.cpp b/mlir/test/lib/Conversion/MemRefToLLVM/TestMemRefToLLVMWithTransforms.cpp new file mode 100644 index 0000000000000..af3b6608aea16 --- /dev/null +++ b/mlir/test/lib/Conversion/MemRefToLLVM/TestMemRefToLLVMWithTransforms.cpp @@ -0,0 +1,62 @@ +//===- TestMemRefToLLVMWithTransforms.cpp ---------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +struct TestMemRefToLLVMWithTransforms + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMemRefToLLVMWithTransforms) + + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + } + + StringRef getArgument() const final { + return "test-memref-to-llvm-with-transforms"; + } + + StringRef getDescription() const final { + return "Tests conversion of MemRef dialects + `func.func` to LLVM dialect " + "with MemRef transforms."; + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + LowerToLLVMOptions options(ctx); + LLVMTypeConverter typeConverter(ctx, options); + RewritePatternSet patterns(ctx); + memref::populateExpandStridedMetadataPatterns(patterns); + populateFuncToLLVMConversionPatterns(typeConverter, patterns); + LLVMConversionTarget target(getContext()); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +namespace mlir::test { +void registerTestMemRefToLLVMWithTransforms() { + PassRegistration(); +} +} // namespace mlir::test diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 34db3051d36a0..26d7597347a8a 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -28,6 +28,7 @@ if(MLIR_INCLUDE_TESTS) MLIRMathTestPasses MLIRTestMathToVCIX MLIRMemRefTestPasses + MLIRTestMemRefToLLVMWithTransforms MLIRMeshTest MLIRNVGPUTestPasses MLIRSCFTestPasses diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 2e08ae6f37980..6ef9ff8e84545 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -130,6 +130,7 @@ void registerTestMathToVCIXPass(); void registerTestIrdlTestDialectConversionPass(); void registerTestMemRefDependenceCheck(); void registerTestMemRefStrideCalculation(); +void registerTestMemRefToLLVMWithTransforms(); void registerTestMeshReshardingSpmdizationPass(); void registerTestMeshSimplificationsPass(); void registerTestMultiBuffering(); @@ -275,6 +276,7 @@ void registerTestPasses() { mlir::test::registerTestMathToVCIXPass(); mlir::test::registerTestMemRefDependenceCheck(); mlir::test::registerTestMemRefStrideCalculation(); + mlir::test::registerTestMemRefToLLVMWithTransforms(); mlir::test::registerTestMeshReshardingSpmdizationPass(); mlir::test::registerTestMeshSimplificationsPass(); mlir::test::registerTestMultiBuffering();