From 29d5a08890dff87b938fced7db5a91b80b2fa217 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Tue, 14 Oct 2025 10:20:29 -0400 Subject: [PATCH] Revert "[mlir] Add strided metadata range dataflow analysis (#161280)" This reverts commit aa8499863ad23350da0912d99d189f306d0ea139. --- .../DataFlow/StridedMetadataRangeAnalysis.h | 54 ------- mlir/include/mlir/Dialect/MemRef/IR/MemRef.h | 1 - .../mlir/Dialect/MemRef/IR/MemRefOps.td | 2 - mlir/include/mlir/Interfaces/CMakeLists.txt | 1 - .../mlir/Interfaces/InferIntRangeInterface.h | 12 +- .../InferStridedMetadataInterface.h | 145 ------------------ .../InferStridedMetadataInterface.td | 45 ------ mlir/lib/Analysis/CMakeLists.txt | 1 - .../DataFlow/StridedMetadataRangeAnalysis.cpp | 127 --------------- mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | 3 +- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 59 ------- mlir/lib/Interfaces/CMakeLists.txt | 2 - .../lib/Interfaces/InferIntRangeInterface.cpp | 19 --- .../InferStridedMetadataInterface.cpp | 36 ----- .../test-strided-metadata-range-analysis.mlir | 67 -------- mlir/test/lib/Analysis/CMakeLists.txt | 1 - .../TestStridedMetadataRangeAnalysis.cpp | 86 ----------- mlir/tools/mlir-opt/mlir-opt.cpp | 2 - 18 files changed, 2 insertions(+), 661 deletions(-) delete mode 100644 mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h delete mode 100644 mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h delete mode 100644 mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td delete mode 100644 mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp delete mode 100644 mlir/lib/Interfaces/InferStridedMetadataInterface.cpp delete mode 100644 mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir delete mode 100644 mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp diff --git a/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h deleted file mode 100644 index 72ac2477435db..0000000000000 --- a/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h +++ /dev/null @@ -1,54 +0,0 @@ -//===- StridedMetadataRange.h - Strided metadata range analysis -*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H -#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H - -#include "mlir/Analysis/DataFlow/SparseAnalysis.h" -#include "mlir/Interfaces/InferStridedMetadataInterface.h" - -namespace mlir { -namespace dataflow { - -/// This lattice element represents the strided metadata of an SSA value. -class StridedMetadataRangeLattice : public Lattice { -public: - using Lattice::Lattice; -}; - -/// Strided metadata range analysis determines the strided metadata ranges of -/// SSA values using operations that define `InferStridedMetadataInterface`. -/// -/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and -/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not -/// loaded in the same solver context. -class StridedMetadataRangeAnalysis - : public SparseForwardDataFlowAnalysis { -public: - StridedMetadataRangeAnalysis(DataFlowSolver &solver, - int32_t indexBitwidth = 64); - - /// At an entry point, we cannot reason about strided metadata ranges unless - /// the type also encodes the data. For example, a memref with static layout. - void setToEntryState(StridedMetadataRangeLattice *lattice) override; - - /// Visit an operation. Invoke the transfer function on each operation that - /// implements `InferStridedMetadataInterface`. - LogicalResult - visitOperation(Operation *op, - ArrayRef operands, - ArrayRef results) override; - -private: - /// Index bitwidth to use when operating with the int-ranges. - int32_t indexBitwidth = 64; -}; -} // namespace dataflow -} // end namespace mlir - -#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h index 69447f74ec403..30f33ed2fd1d6 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -17,7 +17,6 @@ #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" -#include "mlir/Interfaces/InferStridedMetadataInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/MemOpInterfaces.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index b39207fc30dd7..89bd0f103d9f3 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -14,7 +14,6 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferIntRangeInterface.td" -include "mlir/Interfaces/InferStridedMetadataInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/MemOpInterfaces.td" include "mlir/Interfaces/MemorySlotInterfaces.td" @@ -2086,7 +2085,6 @@ def MemRef_StoreOp : MemRef_Op<"store", def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AttrSizedOperandSegments, diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index 72ed046a1ba5d..a5feb592045c0 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -6,7 +6,6 @@ add_mlir_interface(DestinationStyleOpInterface) add_mlir_interface(FunctionInterfaces) add_mlir_interface(IndexingMapOpInterface) add_mlir_interface(InferIntRangeInterface) -add_mlir_interface(InferStridedMetadataInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(MemOpInterfaces) diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h index a6de3d1885eec..0e107e88f5232 100644 --- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h @@ -117,8 +117,7 @@ class IntegerValueRange { IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {} /// Create an integer value range lattice value. - explicit IntegerValueRange( - std::optional value = std::nullopt) + IntegerValueRange(std::optional value = std::nullopt) : value(std::move(value)) {} /// Whether the range is uninitialized. This happens when the state hasn't @@ -168,15 +167,6 @@ using SetIntRangeFn = using SetIntLatticeFn = llvm::function_ref; -/// Helper callback type to get the integer range of a value. -using GetIntRangeFn = function_ref; - -/// Helper function to collect the integer range values of an array of op fold -/// results. -SmallVector getIntValueRanges(ArrayRef values, - GetIntRangeFn getIntRange, - int32_t indexBitwidth); - class InferIntRangeInterface; namespace intrange::detail { diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h deleted file mode 100644 index 0c572e0196a03..0000000000000 --- a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h +++ /dev/null @@ -1,145 +0,0 @@ -//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains definitions of the strided metadata inference interface -// defined in `InferStridedMetadataInterface.td` -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H -#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H - -#include "mlir/Interfaces/InferIntRangeInterface.h" - -namespace mlir { -/// A class that represents the strided metadata range information, including -/// offsets, sizes, and strides as integer ranges. -class StridedMetadataRange { -public: - /// Default constructor creates uninitialized ranges. - StridedMetadataRange() = default; - - /// Returns a ranked strided metadata range. - static StridedMetadataRange - getRanked(SmallVectorImpl &&offsets, - SmallVectorImpl &&sizes, - SmallVectorImpl &&strides) { - return StridedMetadataRange(std::move(offsets), std::move(sizes), - std::move(strides)); - } - - /// Returns a strided metadata range with maximum ranges. - static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, - int32_t offsetsRank, - int32_t sizeRank, - int32_t stridedRank) { - return StridedMetadataRange( - SmallVector( - offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)), - SmallVector( - sizeRank, ConstantIntRanges::maxRange(indexBitwidth)), - SmallVector( - stridedRank, ConstantIntRanges::maxRange(indexBitwidth))); - } - - static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, - int32_t rank) { - return getMaxRanges(indexBitwidth, 1, rank, rank); - } - - /// Returns whether the metadata is uninitialized. - bool isUninitialized() const { return !offsets.has_value(); } - - /// Get the offsets range. - ArrayRef getOffsets() const { - return offsets ? *offsets : ArrayRef(); - } - MutableArrayRef getOffsets() { - return offsets ? *offsets : MutableArrayRef(); - } - - /// Get the sizes ranges. - ArrayRef getSizes() const { return sizes; } - MutableArrayRef getSizes() { return sizes; } - - /// Get the strides ranges. - ArrayRef getStrides() const { return strides; } - MutableArrayRef getStrides() { return strides; } - - /// Compare two strided metadata ranges. - bool operator==(const StridedMetadataRange &other) const { - return offsets == other.offsets && sizes == other.sizes && - strides == other.strides; - } - - /// Print the strided metadata range. - void print(raw_ostream &os) const; - - /// Join two strided metadata ranges, by taking the element-wise union of the - /// metadata. - static StridedMetadataRange join(const StridedMetadataRange &lhs, - const StridedMetadataRange &rhs) { - if (lhs.isUninitialized()) - return rhs; - if (rhs.isUninitialized()) - return lhs; - - // Helper fuction to compute the range union of constant ranges. - auto rangeUnion = - +[](const std::tuple &lhsRhs) - -> ConstantIntRanges { - return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs)); - }; - - // Get the elementwise range union. Note, that `zip_equal` will assert if - // sizes are not equal. - SmallVector offsets = llvm::map_to_vector( - llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion); - SmallVector sizes = - llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion); - SmallVector strides = llvm::map_to_vector( - llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion); - - // Return the joined metadata. - return StridedMetadataRange(std::move(offsets), std::move(sizes), - std::move(strides)); - } - -private: - /// Create a strided metadata range with the given offset, sizes, and strides. - StridedMetadataRange(SmallVectorImpl &&offsets, - SmallVectorImpl &&sizes, - SmallVectorImpl &&strides) - : offsets(std::move(offsets)), sizes(std::move(sizes)), - strides(std::move(strides)) {} - - /// The offsets range. - std::optional> offsets; - - /// The sizes ranges. - SmallVector sizes; - - /// The strides ranges. - SmallVector strides; -}; - -/// Print the strided metadata to `os`. -inline raw_ostream &operator<<(raw_ostream &os, - const StridedMetadataRange &range) { - range.print(os); - return os; -} - -/// Callback function type for setting the strided metadata of a value. -using SetStridedMetadataRangeFn = - function_ref; -} // end namespace mlir - -#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc" - -#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td deleted file mode 100644 index ee5b0942f683e..0000000000000 --- a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td +++ /dev/null @@ -1,45 +0,0 @@ -//===- InferStridedMetadataInterface.td - Strided MD Inference ----------*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Defines the interface for strided metadata range analysis -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE -#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE - -include "mlir/IR/OpBase.td" - -def InferStridedMetadataOpInterface : - OpInterface<"InferStridedMetadataOpInterface"> { - let description = [{ - Allows operations to participate in strided metadata analysis by providing - methods that allow them to specify bounds on offsets, sizes, and strides - of their result(s) given bounds on their input(s) if known. - }]; - let cppNamespace = "::mlir"; - - let methods = [ - InterfaceMethod<[{ - Infer the strided metadata bounds on the results of this op given - the bounds on its operands. - For each result value or block argument of interest, the method should - call `setMetadata` with that `Value` as an argument. - The `operands` parameter contains the strided metadata ranges for all the - operands of the operation in order. - The `getIntRange` callback is provided for obtaining the int-range - analysis result for a given value. - }], - "void", "inferStridedMetadataRanges", - (ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands, - "::mlir::GetIntRangeFn":$getIntRange, - "::mlir::SetStridedMetadataRangeFn":$setMetadata, - "int32_t":$indexBitwidth)> - ]; -} -#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index bef189600d8e7..609cb34309829 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -40,7 +40,6 @@ add_mlir_library(MLIRAnalysis DataFlow/IntegerRangeAnalysis.cpp DataFlow/LivenessAnalysis.cpp DataFlow/SparseAnalysis.cpp - DataFlow/StridedMetadataRangeAnalysis.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp deleted file mode 100644 index 01c9dafaddf10..0000000000000 --- a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp +++ /dev/null @@ -1,127 +0,0 @@ -//===- StridedMetadataRangeAnalysis.cpp - Integer range analysis --------*- C++ -//-*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the dataflow analysis class for integer range inference -// which is used in transformations over the `arith` dialect such as -// branch elimination or signed->unsigned rewriting -// -//===----------------------------------------------------------------------===// - -#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h" -#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/DebugStringHelper.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/DebugLog.h" - -#define DEBUG_TYPE "strided-metadata-range-analysis" - -using namespace mlir; -using namespace mlir::dataflow; - -/// Get the entry state for a value. For any value that is not a ranked memref, -/// this function sets the metadata to a top state with no offsets, sizes, or -/// strides. For `memref` types, this function will use the metadata in the type -/// to try to deduce as much informaiton as possible. -static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) { - // TODO: generalize this method with a type interface. - auto mTy = dyn_cast(v.getType()); - - // If not a memref or it's un-ranked, don't infer any metadata. - if (!mTy || !mTy.hasRank()) - return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0); - - // Get the top state. - auto metadata = - StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank()); - - // Compute the offset and strides. - int64_t offset; - SmallVector strides; - if (failed(cast(mTy).getStridesAndOffset(strides, offset))) - return metadata; - - // Refine the metadata if we know it from the type. - if (!ShapedType::isDynamic(offset)) { - metadata.getOffsets()[0] = - ConstantIntRanges::constant(APInt(indexBitwidth, offset)); - } - for (auto &&[size, range] : - llvm::zip_equal(mTy.getShape(), metadata.getSizes())) { - if (ShapedType::isDynamic(size)) - continue; - range = ConstantIntRanges::constant(APInt(indexBitwidth, size)); - } - for (auto &&[stride, range] : - llvm::zip_equal(strides, metadata.getStrides())) { - if (ShapedType::isDynamic(stride)) - continue; - range = ConstantIntRanges::constant(APInt(indexBitwidth, stride)); - } - - return metadata; -} - -StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis( - DataFlowSolver &solver, int32_t indexBitwidth) - : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) { - assert(indexBitwidth > 0 && "invalid bitwidth"); -} - -void StridedMetadataRangeAnalysis::setToEntryState( - StridedMetadataRangeLattice *lattice) { - propagateIfChanged(lattice, lattice->join(getEntryStateImpl( - lattice->getAnchor(), indexBitwidth))); -} - -LogicalResult StridedMetadataRangeAnalysis::visitOperation( - Operation *op, ArrayRef operands, - ArrayRef results) { - auto inferrable = dyn_cast(op); - - // Bail if we cannot reason about the op. - if (!inferrable) { - setAllToEntryStates(results); - return success(); - } - - LDBG() << "Inferring metadata for: " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - - // Helper function to retrieve int range values. - auto getIntRange = [&](Value value) -> IntegerValueRange { - auto lattice = getOrCreateFor( - getProgramPointAfter(op), value); - return lattice ? lattice->getValue() : IntegerValueRange(); - }; - - // Convert the arguments lattices to a vector. - SmallVector argRanges = llvm::map_to_vector( - operands, [](const StridedMetadataRangeLattice *lattice) { - return lattice->getValue(); - }); - - // Callback to set metadata on a result. - auto joinCallback = [&](Value v, const StridedMetadataRange &md) { - auto result = cast(v); - assert(llvm::is_contained(op->getResults(), result)); - LDBG() << "- Inferred metadata: " << md; - StridedMetadataRangeLattice *lattice = results[result.getResultNumber()]; - ChangeResult changed = lattice->join(md); - LDBG() << "- Joined metadata: " << lattice->getValue(); - propagateIfChanged(lattice, changed); - }; - - // Infer the metadata. - inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback, - indexBitwidth); - return success(); -} diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index 1382c7aceea79..e25a0121a3359 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR + ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect DEPENDS MLIRMemRefOpsIncGen @@ -18,7 +18,6 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRDialectUtils MLIRInferIntRangeCommon MLIRInferIntRangeInterface - MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRIR MLIRMemOpInterfaces diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 507597b4707c4..e9bdcda296da5 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3437,65 +3437,6 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) { return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable()); } -void SubViewOp::inferStridedMetadataRanges( - ArrayRef ranges, GetIntRangeFn getIntRange, - SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) { - auto isUninitialized = - +[](IntegerValueRange range) { return range.isUninitialized(); }; - - // Bail early if any of the operands metadata is not ready: - SmallVector offsetOperands = - getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth); - if (llvm::any_of(offsetOperands, isUninitialized)) - return; - - SmallVector sizeOperands = - getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth); - if (llvm::any_of(sizeOperands, isUninitialized)) - return; - - SmallVector stridesOperands = - getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth); - if (llvm::any_of(stridesOperands, isUninitialized)) - return; - - StridedMetadataRange sourceRange = - ranges[getSourceMutable().getOperandNumber()]; - if (sourceRange.isUninitialized()) - return; - - ArrayRef srcStrides = sourceRange.getStrides(); - - // Get the dropped dims. - llvm::SmallBitVector droppedDims = getDroppedDims(); - - // Compute the new offset, strides and sizes. - ConstantIntRanges offset = sourceRange.getOffsets()[0]; - SmallVector strides, sizes; - - for (size_t i = 0, e = droppedDims.size(); i < e; ++i) { - bool dropped = droppedDims.test(i); - // Compute the new offset. - ConstantIntRanges off = - intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]}); - offset = intrange::inferAdd({offset, off}); - - // Skip dropped dimensions. - if (dropped) - continue; - // Multiply the strides. - strides.push_back( - intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]})); - // Get the sizes. - sizes.push_back(sizeOperands[i].getValue()); - } - - setMetadata(getResult(), - StridedMetadataRange::getRanked( - SmallVector({std::move(offset)}), - std::move(sizes), std::move(strides))); -} - //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index ad020eb431ee0..388de1c3e5abf 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -9,7 +9,6 @@ set(LLVM_OPTIONAL_SOURCES FunctionInterfaces.cpp IndexingMapOpInterface.cpp InferIntRangeInterface.cpp - InferStridedMetadataInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp MemOpInterfaces.cpp @@ -65,7 +64,6 @@ add_mlir_library(MLIRFunctionInterfaces add_mlir_interface_library(IndexingMapOpInterface) add_mlir_interface_library(InferIntRangeInterface) -add_mlir_interface_library(InferStridedMetadataInterface) add_mlir_interface_library(InferTypeOpInterface) add_mlir_library(MLIRLoopLikeInterface diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp index 84fc9b8b61a11..9f3e97d051c85 100644 --- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -146,25 +146,6 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) { return os; } -SmallVector -mlir::getIntValueRanges(ArrayRef values, - GetIntRangeFn getIntRange, int32_t indexBitwidth) { - SmallVector ranges; - ranges.reserve(values.size()); - for (OpFoldResult ofr : values) { - if (auto value = dyn_cast(ofr)) { - ranges.push_back(getIntRange(value)); - continue; - } - - // Create a constant range. - auto attr = cast(cast(ofr)); - ranges.emplace_back(ConstantIntRanges::constant( - attr.getValue().sextOrTrunc(indexBitwidth))); - } - return ranges; -} - void mlir::intrange::detail::defaultInferResultRanges( InferIntRangeInterface interface, ArrayRef argRanges, SetIntLatticeFn setResultRanges) { diff --git a/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp deleted file mode 100644 index 483e9f192cdcd..0000000000000 --- a/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp +++ /dev/null @@ -1,36 +0,0 @@ -//===- InferStridedMetadataInterface.cpp - Strided md inference interface -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Interfaces/InferStridedMetadataInterface.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include - -using namespace mlir; - -#include "mlir/Interfaces/InferStridedMetadataInterface.cpp.inc" - -void StridedMetadataRange::print(raw_ostream &os) const { - if (isUninitialized()) { - os << "strided_metadata"; - return; - } - os << "strided_metadata"; -} diff --git a/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir b/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir deleted file mode 100644 index 808c1c2bfd2a8..0000000000000 --- a/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir +++ /dev/null @@ -1,67 +0,0 @@ -// RUN: mlir-opt -test-strided-metadata-range-analysis %s 2>&1 | FileCheck %s - -func.func @memref_subview(%arg0: memref<8x16x4xf32, strided<[64, 4, 1]>>, %arg1: memref<1x128x1x32x1xf32, strided<[4096, 32, 32, 1, 1]>>, %arg2: memref<8x16x4xf32, strided<[1, 64, 8], offset: 16>>, %arg3: index, %arg4: index, %arg5: index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %0 = test.with_bounds {smax = 13 : index, smin = 11 : index, umax = 13 : index, umin = 11 : index} : index - %1 = test.with_bounds {smax = 7 : index, smin = 5 : index, umax = 7 : index, umin = 5 : index} : index - - // Test subview with unknown sizes, and constant offsets and strides. - // CHECK: Op: %[[SV0:.*]] = memref.subview - // CHECK-NEXT: result[0]: strided_metadata< - // CHECK-SAME: offset = [{unsigned : [1, 1] signed : [1, 1]}] - // CHECK-SAME: sizes = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] - // CHECK-SAME: strides = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [4, 4] signed : [4, 4]}, {unsigned : [1, 1] signed : [1, 1]}] - %subview = memref.subview %arg0[%c0, %c0, %c1] [%arg3, %arg4, %arg5] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[64, 4, 1]>> to memref> - - // Test a subview of a subview, with bounded dynamic offsets. - // CHECK: Op: %[[SV1:.*]] = memref.subview - // CHECK-NEXT: result[0]: strided_metadata< - // CHECK-SAME: offset = [{unsigned : [346, 484] signed : [346, 484]}] - // CHECK-SAME: sizes = [{unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}] - // CHECK-SAME: strides = [{unsigned : [704, 832] signed : [704, 832]}, {unsigned : [44, 52] signed : [44, 52]}, {unsigned : [11, 13] signed : [11, 13]}] - %subview_0 = memref.subview %subview[%1, %1, %1] [%c2, %c2, %c2] [%0, %0, %0] : memref> to memref> - - // Test a subview of a subview, with constant operands. - // CHECK: Op: %[[SV2:.*]] = memref.subview - // CHECK-NEXT: result[0]: strided_metadata< - // CHECK-SAME: offset = [{unsigned : [368, 510] signed : [368, 510]}] - // CHECK-SAME: sizes = [{unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}] - // CHECK-SAME: strides = [{unsigned : [704, 832] signed : [704, 832]}, {unsigned : [44, 52] signed : [44, 52]}, {unsigned : [11, 13] signed : [11, 13]}] - %subview_1 = memref.subview %subview_0[%c0, %c0, %c2] [%c2, %c2, %c2] [%c1, %c1, %c1] : memref> to memref> - - // Test a rank-reducing subview. - // CHECK: Op: %[[SV3:.*]] = memref.subview - // CHECK-NEXT: result[0]: strided_metadata< - // CHECK-SAME: offset = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] - // CHECK-SAME: sizes = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [16, 16] signed : [16, 16]}] - // CHECK-SAME: strides = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] - %subview_2 = memref.subview %arg1[%arg4, %arg4, %arg4, %arg4, %arg4] [1, 64, 1, 16, 1] [%arg5, %arg5, %arg5, %arg5, %arg5] : memref<1x128x1x32x1xf32, strided<[4096, 32, 32, 1, 1]>> to memref<64x16xf32, strided<[?, ?], offset: ?>> - - // Test a subview of a rank-reducing subview - // CHECK: Op: %[[SV4:.*]] = memref.subview - // CHECK-NEXT: result[0]: strided_metadata< - // CHECK-SAME: offset = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] - // CHECK-SAME: sizes = [{unsigned : [5, 7] signed : [5, 7]}] - // CHECK-SAME: strides = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] - %subview_3 = memref.subview %subview_2[%c0, %0] [1, %1] [%c1, %c2] : memref<64x16xf32, strided<[?, ?], offset: ?>> to memref> - - // Test a subview with mixed bounded and unbound dynamic sizes. - // CHECK: Op: %[[SV5:.*]] = memref.subview - // CHECK-NEXT: result[0]: strided_metadata< - // CHECK-SAME: offset = [{unsigned : [32, 32] signed : [32, 32]}] - // CHECK-SAME: sizes = [{unsigned : [11, 13] signed : [11, 13]}, {unsigned : [5, 7] signed : [5, 7]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] - // CHECK-SAME: strides = [{unsigned : [1, 1] signed : [1, 1]}, {unsigned : [64, 64] signed : [64, 64]}, {unsigned : [8, 8] signed : [8, 8]}] - %subview_4 = memref.subview %arg2[%c0, %c0, %c2] [%0, %1, %arg5] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[1, 64, 8], offset: 16>> to memref> - return -} - -// CHECK: func.func @memref_subview -// CHECK: %[[A0:.*]]: memref<8x16x4xf32, strided<[64, 4, 1]>> -// CHECK: %[[SV0]] = memref.subview %[[A0]] -// CHECK-NEXT: %[[SV1]] = memref.subview -// CHECK-NEXT: %[[SV2]] = memref.subview -// CHECK-NEXT: %[[SV3]] = memref.subview -// CHECK-NEXT: %[[SV4]] = memref.subview -// CHECK-NEXT: %[[SV5]] = memref.subview diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt index c37671ade37b3..91879981bffd2 100644 --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -17,7 +17,6 @@ add_mlir_library(MLIRTestAnalysis DataFlow/TestDenseForwardDataFlowAnalysis.cpp DataFlow/TestLivenessAnalysis.cpp DataFlow/TestSparseBackwardDataFlowAnalysis.cpp - DataFlow/TestStridedMetadataRangeAnalysis.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp deleted file mode 100644 index 6ac09fdeed136..0000000000000 --- a/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp +++ /dev/null @@ -1,86 +0,0 @@ -//===- TestStridedMetadataRangeAnalysis.cpp - Test strided md analysis ----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" -#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" -#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" -#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h" -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Operation.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; -using namespace mlir::dataflow; - -static void printAnalysisResults(DataFlowSolver &solver, Operation *op, - raw_ostream &os) { - // Collect the strided metadata of the op results. - SmallVector> results; - for (OpResult result : op->getResults()) { - const auto *state = solver.lookupState(result); - // Skip the result if it's uninitialized. - if (!state || state->getValue().isUninitialized()) - continue; - - // Skip the result if the range is empty. - const mlir::StridedMetadataRange &md = state->getValue(); - if (md.getOffsets().empty() && md.getSizes().empty() && - md.getStrides().empty()) - continue; - results.push_back({result.getResultNumber(), state}); - } - - // Early exit if there's no metadata to print. - if (results.empty()) - return; - - // Print the metadata. - os << "Op: " << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "\n"; - for (auto [idx, state] : results) - os << " result[" << idx << "]: " << state->getValue() << "\n"; - os << "\n"; -} - -namespace { -struct TestStridedMetadataRangeAnalysisPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestStridedMetadataRangeAnalysisPass) - - StringRef getArgument() const override { - return "test-strided-metadata-range-analysis"; - } - void runOnOperation() override { - Operation *op = getOperation(); - - DataFlowSolver solver; - solver.load(); - solver.load(); - solver.load(); - solver.load(); - if (failed(solver.initializeAndRun(op))) - return signalPassFailure(); - - op->walk( - [&](Operation *op) { printAnalysisResults(solver, op, llvm::errs()); }); - } -}; -} // end anonymous namespace - -namespace mlir { -namespace test { -void registerTestStridedMetadataRangeAnalysisPass() { - PassRegistration(); -} -} // end namespace test -} // end namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 88421800fed1e..6432fae615f88 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -151,7 +151,6 @@ void registerTestSliceAnalysisPass(); void registerTestSPIRVCPURunnerPipeline(); void registerTestSPIRVFuncSignatureConversion(); void registerTestSPIRVVectorUnrolling(); -void registerTestStridedMetadataRangeAnalysisPass(); void registerTestTensorCopyInsertionPass(); void registerTestTensorLikeAndBufferLikePass(); void registerTestTensorTransforms(); @@ -300,7 +299,6 @@ void registerTestPasses() { mlir::test::registerTestSPIRVCPURunnerPipeline(); mlir::test::registerTestSPIRVFuncSignatureConversion(); mlir::test::registerTestSPIRVVectorUnrolling(); - mlir::test::registerTestStridedMetadataRangeAnalysisPass(); mlir::test::registerTestTensorCopyInsertionPass(); mlir::test::registerTestTensorLikeAndBufferLikePass(); mlir::test::registerTestTensorTransforms();