Skip to content

Commit ff7bed9

Browse files
jtuylsaokblast
authored andcommitted
[MemRef] Implement value bounds interface for CollapseShapeOp (llvm#164955)
1 parent 8b2e518 commit ff7bed9

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,27 @@ struct RankOpInterface
9898
}
9999
};
100100

101+
struct CollapseShapeOpInterface
102+
: public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
103+
memref::CollapseShapeOp> {
104+
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
105+
ValueBoundsConstraintSet &cstr) const {
106+
auto collapseOp = cast<memref::CollapseShapeOp>(op);
107+
assert(value == collapseOp.getResult() && "invalid value");
108+
109+
// Multiply the expressions for the dimensions in the reassociation group.
110+
const ReassociationIndices &reassocIndices =
111+
collapseOp.getReassociationIndices()[dim];
112+
AffineExpr productExpr =
113+
cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
114+
for (size_t i = 1; i < reassocIndices.size(); ++i) {
115+
productExpr =
116+
productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
117+
}
118+
cstr.bound(value)[dim] == productExpr;
119+
}
120+
};
121+
101122
struct SubViewOpInterface
102123
: public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
103124
SubViewOp> {
@@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
134155
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
135156
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
136157
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
158+
memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
159+
*ctx);
137160
memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
138161
*ctx);
139162
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);

mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,24 @@ func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) {
7777

7878
// -----
7979

80+
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
81+
// CHECK-LABEL: func @memref_collapse(
82+
// CHECK-SAME: %[[sz0:.*]]: index
83+
// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
84+
// CHECK-DAG: %[[c12:.*]] = arith.constant 12 : index
85+
// CHECK: %[[dim:.*]] = memref.dim %{{.*}}, %[[c2]] : memref<3x4x?x2xf32>
86+
// CHECK: %[[mul:.*]] = affine.apply #[[$MAP]]()[%[[dim]]]
87+
// CHECK: return %[[c12]], %[[mul]]
88+
func.func @memref_collapse(%sz0: index) -> (index, index) {
89+
%0 = memref.alloc(%sz0) : memref<3x4x?x2xf32>
90+
%1 = memref.collapse_shape %0 [[0, 1], [2, 3]] : memref<3x4x?x2xf32> into memref<12x?xf32>
91+
%2 = "test.reify_bound"(%1) {dim = 0} : (memref<12x?xf32>) -> (index)
92+
%3 = "test.reify_bound"(%1) {dim = 1} : (memref<12x?xf32>) -> (index)
93+
return %2, %3 : index, index
94+
}
95+
96+
// -----
97+
8098
// CHECK-LABEL: func @memref_get_global(
8199
// CHECK: %[[c4:.*]] = arith.constant 4 : index
82100
// CHECK: return %[[c4]]

0 commit comments

Comments
 (0)