Skip to content

Commit b2797d9

Browse files
authored
Reland "[mlir] Add strided metadata range dataflow analysis" (#163403)" (#163408)
This relands commit aa84998. That commit was originally reverted because it caused failures in shared lib builds due to missing link dependencies. This patch relands the commit with the missing libs added. Signed-off-by: Fabian Mora <[email protected]>
1 parent 932a7d6 commit b2797d9

File tree

18 files changed

+676
-2
lines changed

18 files changed

+676
-2
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//===- StridedMetadataRange.h - Strided metadata range analysis -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
10+
#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
11+
12+
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
13+
#include "mlir/Interfaces/InferStridedMetadataInterface.h"
14+
15+
namespace mlir {
16+
namespace dataflow {
17+
18+
/// This lattice element represents the strided metadata of an SSA value.
19+
class StridedMetadataRangeLattice : public Lattice<StridedMetadataRange> {
20+
public:
21+
using Lattice::Lattice;
22+
};
23+
24+
/// Strided metadata range analysis determines the strided metadata ranges of
25+
/// SSA values using operations that define `InferStridedMetadataInterface`.
26+
///
27+
/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and
28+
/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not
29+
/// loaded in the same solver context.
30+
class StridedMetadataRangeAnalysis
31+
: public SparseForwardDataFlowAnalysis<StridedMetadataRangeLattice> {
32+
public:
33+
StridedMetadataRangeAnalysis(DataFlowSolver &solver,
34+
int32_t indexBitwidth = 64);
35+
36+
/// At an entry point, we cannot reason about strided metadata ranges unless
37+
/// the type also encodes the data. For example, a memref with static layout.
38+
void setToEntryState(StridedMetadataRangeLattice *lattice) override;
39+
40+
/// Visit an operation. Invoke the transfer function on each operation that
41+
/// implements `InferStridedMetadataInterface`.
42+
LogicalResult
43+
visitOperation(Operation *op,
44+
ArrayRef<const StridedMetadataRangeLattice *> operands,
45+
ArrayRef<StridedMetadataRangeLattice *> results) override;
46+
47+
private:
48+
/// Index bitwidth to use when operating with the int-ranges.
49+
int32_t indexBitwidth = 64;
50+
};
51+
} // namespace dataflow
52+
} // end namespace mlir
53+
54+
#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Interfaces/CastInterfaces.h"
1818
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1919
#include "mlir/Interfaces/InferIntRangeInterface.h"
20+
#include "mlir/Interfaces/InferStridedMetadataInterface.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
2122
#include "mlir/Interfaces/MemOpInterfaces.h"
2223
#include "mlir/Interfaces/MemorySlotInterfaces.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
1414
include "mlir/Interfaces/CastInterfaces.td"
1515
include "mlir/Interfaces/ControlFlowInterfaces.td"
1616
include "mlir/Interfaces/InferIntRangeInterface.td"
17+
include "mlir/Interfaces/InferStridedMetadataInterface.td"
1718
include "mlir/Interfaces/InferTypeOpInterface.td"
1819
include "mlir/Interfaces/MemOpInterfaces.td"
1920
include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -2085,6 +2086,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
20852086

20862087
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20872088
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2089+
DeclareOpInterfaceMethods<InferStridedMetadataOpInterface>,
20882090
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
20892091
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
20902092
AttrSizedOperandSegments,

mlir/include/mlir/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_mlir_interface(DestinationStyleOpInterface)
66
add_mlir_interface(FunctionInterfaces)
77
add_mlir_interface(IndexingMapOpInterface)
88
add_mlir_interface(InferIntRangeInterface)
9+
add_mlir_interface(InferStridedMetadataInterface)
910
add_mlir_interface(InferTypeOpInterface)
1011
add_mlir_interface(LoopLikeInterface)
1112
add_mlir_interface(MemOpInterfaces)

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ class IntegerValueRange {
117117
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
118118

119119
/// Create an integer value range lattice value.
120-
IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
120+
explicit IntegerValueRange(
121+
std::optional<ConstantIntRanges> value = std::nullopt)
121122
: value(std::move(value)) {}
122123

123124
/// Whether the range is uninitialized. This happens when the state hasn't
@@ -167,6 +168,15 @@ using SetIntRangeFn =
167168
using SetIntLatticeFn =
168169
llvm::function_ref<void(Value, const IntegerValueRange &)>;
169170

171+
/// Helper callback type to get the integer range of a value.
172+
using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;
173+
174+
/// Helper function to collect the integer range values of an array of op fold
175+
/// results.
176+
SmallVector<IntegerValueRange> getIntValueRanges(ArrayRef<OpFoldResult> values,
177+
GetIntRangeFn getIntRange,
178+
int32_t indexBitwidth);
179+
170180
class InferIntRangeInterface;
171181

172182
namespace intrange::detail {
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains definitions of the strided metadata inference interface
10+
// defined in `InferStridedMetadataInterface.td`
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
15+
#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
16+
17+
#include "mlir/Interfaces/InferIntRangeInterface.h"
18+
19+
namespace mlir {
20+
/// A class that represents the strided metadata range information, including
21+
/// offsets, sizes, and strides as integer ranges.
22+
class StridedMetadataRange {
23+
public:
24+
/// Default constructor creates uninitialized ranges.
25+
StridedMetadataRange() = default;
26+
27+
/// Returns a ranked strided metadata range.
28+
static StridedMetadataRange
29+
getRanked(SmallVectorImpl<ConstantIntRanges> &&offsets,
30+
SmallVectorImpl<ConstantIntRanges> &&sizes,
31+
SmallVectorImpl<ConstantIntRanges> &&strides) {
32+
return StridedMetadataRange(std::move(offsets), std::move(sizes),
33+
std::move(strides));
34+
}
35+
36+
/// Returns a strided metadata range with maximum ranges.
37+
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
38+
int32_t offsetsRank,
39+
int32_t sizeRank,
40+
int32_t stridedRank) {
41+
return StridedMetadataRange(
42+
SmallVector<ConstantIntRanges>(
43+
offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)),
44+
SmallVector<ConstantIntRanges>(
45+
sizeRank, ConstantIntRanges::maxRange(indexBitwidth)),
46+
SmallVector<ConstantIntRanges>(
47+
stridedRank, ConstantIntRanges::maxRange(indexBitwidth)));
48+
}
49+
50+
static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
51+
int32_t rank) {
52+
return getMaxRanges(indexBitwidth, 1, rank, rank);
53+
}
54+
55+
/// Returns whether the metadata is uninitialized.
56+
bool isUninitialized() const { return !offsets.has_value(); }
57+
58+
/// Get the offsets range.
59+
ArrayRef<ConstantIntRanges> getOffsets() const {
60+
return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
61+
}
62+
MutableArrayRef<ConstantIntRanges> getOffsets() {
63+
return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>();
64+
}
65+
66+
/// Get the sizes ranges.
67+
ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
68+
MutableArrayRef<ConstantIntRanges> getSizes() { return sizes; }
69+
70+
/// Get the strides ranges.
71+
ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
72+
MutableArrayRef<ConstantIntRanges> getStrides() { return strides; }
73+
74+
/// Compare two strided metadata ranges.
75+
bool operator==(const StridedMetadataRange &other) const {
76+
return offsets == other.offsets && sizes == other.sizes &&
77+
strides == other.strides;
78+
}
79+
80+
/// Print the strided metadata range.
81+
void print(raw_ostream &os) const;
82+
83+
/// Join two strided metadata ranges, by taking the element-wise union of the
84+
/// metadata.
85+
static StridedMetadataRange join(const StridedMetadataRange &lhs,
86+
const StridedMetadataRange &rhs) {
87+
if (lhs.isUninitialized())
88+
return rhs;
89+
if (rhs.isUninitialized())
90+
return lhs;
91+
92+
// Helper fuction to compute the range union of constant ranges.
93+
auto rangeUnion =
94+
+[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs)
95+
-> ConstantIntRanges {
96+
return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
97+
};
98+
99+
// Get the elementwise range union. Note, that `zip_equal` will assert if
100+
// sizes are not equal.
101+
SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
102+
llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
103+
SmallVector<ConstantIntRanges> sizes =
104+
llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
105+
SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
106+
llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);
107+
108+
// Return the joined metadata.
109+
return StridedMetadataRange(std::move(offsets), std::move(sizes),
110+
std::move(strides));
111+
}
112+
113+
private:
114+
/// Create a strided metadata range with the given offset, sizes, and strides.
115+
StridedMetadataRange(SmallVectorImpl<ConstantIntRanges> &&offsets,
116+
SmallVectorImpl<ConstantIntRanges> &&sizes,
117+
SmallVectorImpl<ConstantIntRanges> &&strides)
118+
: offsets(std::move(offsets)), sizes(std::move(sizes)),
119+
strides(std::move(strides)) {}
120+
121+
/// The offsets range.
122+
std::optional<SmallVector<ConstantIntRanges>> offsets;
123+
124+
/// The sizes ranges.
125+
SmallVector<ConstantIntRanges> sizes;
126+
127+
/// The strides ranges.
128+
SmallVector<ConstantIntRanges> strides;
129+
};
130+
131+
/// Print the strided metadata to `os`.
132+
inline raw_ostream &operator<<(raw_ostream &os,
133+
const StridedMetadataRange &range) {
134+
range.print(os);
135+
return os;
136+
}
137+
138+
/// Callback function type for setting the strided metadata of a value.
139+
using SetStridedMetadataRangeFn =
140+
function_ref<void(Value, const StridedMetadataRange &)>;
141+
} // end namespace mlir
142+
143+
#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc"
144+
145+
#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===- InferStridedMetadataInterface.td - Strided MD Inference ----------*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines the interface for strided metadata range analysis
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
14+
#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
def InferStridedMetadataOpInterface :
19+
OpInterface<"InferStridedMetadataOpInterface"> {
20+
let description = [{
21+
Allows operations to participate in strided metadata analysis by providing
22+
methods that allow them to specify bounds on offsets, sizes, and strides
23+
of their result(s) given bounds on their input(s) if known.
24+
}];
25+
let cppNamespace = "::mlir";
26+
27+
let methods = [
28+
InterfaceMethod<[{
29+
Infer the strided metadata bounds on the results of this op given
30+
the bounds on its operands.
31+
For each result value or block argument of interest, the method should
32+
call `setMetadata` with that `Value` as an argument.
33+
The `operands` parameter contains the strided metadata ranges for all the
34+
operands of the operation in order.
35+
The `getIntRange` callback is provided for obtaining the int-range
36+
analysis result for a given value.
37+
}],
38+
"void", "inferStridedMetadataRanges",
39+
(ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands,
40+
"::mlir::GetIntRangeFn":$getIntRange,
41+
"::mlir::SetStridedMetadataRangeFn":$setMetadata,
42+
"int32_t":$indexBitwidth)>
43+
];
44+
}
45+
#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE

mlir/lib/Analysis/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
4040
DataFlow/IntegerRangeAnalysis.cpp
4141
DataFlow/LivenessAnalysis.cpp
4242
DataFlow/SparseAnalysis.cpp
43+
DataFlow/StridedMetadataRangeAnalysis.cpp
4344

4445
ADDITIONAL_HEADER_DIRS
4546
${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
@@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis
5354
MLIRDataLayoutInterfaces
5455
MLIRFunctionInterfaces
5556
MLIRInferIntRangeInterface
57+
MLIRInferStridedMetadataInterface
5658
MLIRInferTypeOpInterface
5759
MLIRLoopLikeInterface
5860
MLIRPresburger

0 commit comments

Comments
 (0)