@@ -51,68 +51,29 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
5151 return cast<Value>(in);
5252}
5353
54- static bool hasDynamicDim (ArrayRef<OpFoldResult> dims) {
55- for (auto &&dim : dims) {
56- auto constant = getConstantIntValue (dim);
57- if (!constant || *constant < 0 ) {
58- return true ;
59- }
60- }
61- return false ;
62- }
63-
64- static OpFoldResult computeStaticShape (OpBuilder &builder, Location loc,
65- ArrayRef<OpFoldResult> dims,
66- ArrayRef<OpFoldResult> strides) {
67- // max(dims[i] * strides[i]) for i = 0, 1, ..., n-1
68- int64_t maxSize = 1 ;
69- for (auto &&[dim, stride] : llvm::zip (dims, strides)) {
70- AffineExpr s0, s1;
71- bindSymbols (builder.getContext (), s0, s1);
72- OpFoldResult size = affine::makeComposedFoldedAffineApply (
73- builder, loc, s0 * s1, ArrayRef<OpFoldResult>{dim, stride});
74- auto constant = getConstantIntValue (size);
75- assert (constant && " expected constant value" );
76- maxSize = std::max (maxSize, *constant);
77- }
78- return builder.getIndexAttr (maxSize);
79- }
80-
81- static OpFoldResult computeDynamicShape (OpBuilder &builder, Location loc,
82- ArrayRef<OpFoldResult> dims,
83- ArrayRef<OpFoldResult> strides) {
84-
54+ // / Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the
55+ // / span of the memref.
56+ static OpFoldResult computeSize (OpBuilder &builder, Location loc,
57+ ArrayRef<OpFoldResult> dims,
58+ ArrayRef<OpFoldResult> strides) {
59+ assert (dims.size () == strides.size () &&
60+ " number of dimensions and strides should be equal" );
8561 SmallVector<AffineExpr> symbols (2 * dims.size ());
8662 bindSymbolsList (builder.getContext (), MutableArrayRef{symbols});
8763 SmallVector<AffineExpr> productExpressions;
88- SmallVector<Value > values;
64+ SmallVector<OpFoldResult > values;
8965 size_t symbolIndex = 0 ;
9066 for (auto &&[dim, stride] : llvm::zip (dims, strides)) {
9167 AffineExpr dimExpr = symbols[symbolIndex++];
9268 AffineExpr strideExpr = symbols[symbolIndex++];
9369 productExpressions.push_back (dimExpr * strideExpr);
94- values.push_back (getValueFromOpFoldResult (builder, loc, dim) );
95- values.push_back (getValueFromOpFoldResult (builder, loc, stride) );
70+ values.push_back (dim);
71+ values.push_back (stride);
9672 }
9773
9874 AffineMap maxMap = AffineMap::get (0 , symbols.size (), productExpressions,
9975 builder.getContext ());
100- Value maxValue =
101- builder.create <affine::AffineMaxOp>(loc, maxMap, values).getResult ();
102- return maxValue;
103- }
104-
105- // / Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the
106- // / span of the memref.
107- static OpFoldResult computeSize (OpBuilder &builder, Location loc,
108- ArrayRef<OpFoldResult> dims,
109- ArrayRef<OpFoldResult> strides) {
110- assert (dims.size () == strides.size () &&
111- " number of dimensions and strides should be equal" );
112- if (hasDynamicDim (dims) || hasDynamicDim (strides)) {
113- return computeDynamicShape (builder, loc, dims, strides);
114- }
115- return computeStaticShape (builder, loc, dims, strides);
76+ return affine::makeComposedFoldedAffineMax (builder, loc, maxMap, values);
11677}
11778
11879// / Returns a collapsed memref and the linearized index to access the element
0 commit comments