Skip to content

Commit 72da9b1

Browse files
jtuylsaokblast
authored andcommitted
[MemRef] Implement value bounds interface for ExpandShapeOp (llvm#164438)
1 parent 29e02f5 commit 72da9b1

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,17 @@ struct DimOpInterface
5959
}
6060
};
6161

62+
struct ExpandShapeOpInterface
63+
: public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
64+
memref::ExpandShapeOp> {
65+
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
66+
ValueBoundsConstraintSet &cstr) const {
67+
auto expandOp = cast<memref::ExpandShapeOp>(op);
68+
assert(value == expandOp.getResult() && "invalid value");
69+
cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
70+
}
71+
};
72+
6273
struct GetGlobalOpInterface
6374
: public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
6475
GetGlobalOp> {
@@ -123,6 +134,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
123134
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
124135
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
125136
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
137+
memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
138+
*ctx);
126139
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
127140
memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
128141
memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);

mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,20 @@ func.func @memref_dim_all_positive(%m: memref<?xf32>, %x: index) {
6363

6464
// -----
6565

66+
// CHECK-LABEL: func @memref_expand(
67+
// CHECK-SAME: %[[m:[a-zA-Z0-9]+]]: memref<?xf32>
68+
// CHECK-SAME: %[[sz:[a-zA-Z0-9]+]]: index
69+
// CHECK: %[[c4:.*]] = arith.constant 4 : index
70+
// CHECK: return %[[sz]], %[[c4]]
71+
func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) {
72+
%0 = memref.expand_shape %m [[0, 1]] output_shape [%sz, 4]: memref<?xf32> into memref<?x4xf32>
73+
%1 = "test.reify_bound"(%0) {dim = 0} : (memref<?x4xf32>) -> (index)
74+
%2 = "test.reify_bound"(%0) {dim = 1} : (memref<?x4xf32>) -> (index)
75+
return %1, %2 : index, index
76+
}
77+
78+
// -----
79+
6680
// CHECK-LABEL: func @memref_get_global(
6781
// CHECK: %[[c4:.*]] = arith.constant 4 : index
6882
// CHECK: return %[[c4]]

0 commit comments

Comments
 (0)