@@ -51,6 +51,70 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
5151 return cast<Value>(in);
5252}
5353
54+ static bool hasDynamicDim (ArrayRef<OpFoldResult> dims) {
55+ for (auto &&dim : dims) {
56+ auto constant = getConstantIntValue (dim);
57+ if (!constant || *constant < 0 ) {
58+ return true ;
59+ }
60+ }
61+ return false ;
62+ }
63+
64+ static OpFoldResult computeStaticShape (OpBuilder &builder, Location loc,
65+ ArrayRef<OpFoldResult> dims,
66+ ArrayRef<OpFoldResult> strides) {
67+ // max(dims[i] * strides[i]) for i = 0, 1, ..., n-1
68+ int64_t maxSize = 1 ;
69+ for (auto &&[dim, stride] : llvm::zip (dims, strides)) {
70+ AffineExpr s0, s1;
71+ bindSymbols (builder.getContext (), s0, s1);
72+ OpFoldResult size = affine::makeComposedFoldedAffineApply (
73+ builder, loc, s0 * s1, ArrayRef<OpFoldResult>{dim, stride});
74+ auto constant = getConstantIntValue (size);
75+ assert (constant && " expected constant value" );
76+ maxSize = *constant;
77+ }
78+ return builder.getIndexAttr (maxSize);
79+ }
80+
81+ static OpFoldResult computeDynamicShape (OpBuilder &builder, Location loc,
82+ ArrayRef<OpFoldResult> dims,
83+ ArrayRef<OpFoldResult> strides) {
84+
85+ SmallVector<AffineExpr> symbols (2 * dims.size ());
86+ bindSymbolsList (builder.getContext (), MutableArrayRef{symbols});
87+ SmallVector<AffineExpr> productExpressions;
88+ SmallVector<Value> values;
89+ size_t symbolIndex = 0 ;
90+ for (auto &&[dim, stride] : llvm::zip (dims, strides)) {
91+ AffineExpr dimExpr = symbols[symbolIndex++];
92+ AffineExpr strideExpr = symbols[symbolIndex++];
93+ productExpressions.push_back (dimExpr * strideExpr);
94+ values.push_back (getValueFromOpFoldResult (builder, loc, dim));
95+ values.push_back (getValueFromOpFoldResult (builder, loc, stride));
96+ }
97+
98+ AffineMap maxMap = AffineMap::get (0 , symbols.size (), productExpressions,
99+ builder.getContext ());
100+ Value maxValue =
101+ builder.create <affine::AffineMaxOp>(loc, maxMap, values).getResult ();
102+ return maxValue;
103+ }
104+
105+ // / Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the
106+ // / span of the memref.
107+ static OpFoldResult computeSpan (OpBuilder &builder, Location loc,
108+ ArrayRef<OpFoldResult> dims,
109+ ArrayRef<OpFoldResult> strides) {
110+ assert (dims.size () == strides.size () &&
111+ " number of dimensions and strides should be equal" );
112+ if (hasDynamicDim (dims) || hasDynamicDim (strides)) {
113+ return computeDynamicShape (builder, loc, dims, strides);
114+ }
115+ return computeStaticShape (builder, loc, dims, strides);
116+ }
117+
54118// / Returns a collapsed memref and the linearized index to access the element
55119// / at the specified indices.
56120static std::pair<Value, Value> getFlattenMemrefAndOffset (OpBuilder &rewriter,
@@ -82,10 +146,12 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
82146 rewriter.create <memref::ReinterpretCastOp>(
83147 loc, source,
84148 /* offset = */ linearizedInfo.linearizedOffset ,
85- /* shapes = */ ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize },
149+ /* shapes = */
150+ ArrayRef<OpFoldResult>{computeSpan (
151+ rewriter, loc, stridedMetadata.getConstifiedMixedSizes (),
152+ stridedMetadata.getConstifiedMixedStrides ())},
86153 /* strides = */
87- ArrayRef<OpFoldResult>{
88- stridedMetadata.getConstifiedMixedStrides ().back ()}),
154+ ArrayRef<OpFoldResult>{rewriter.getIndexAttr (1 )}),
89155 getValueFromOpFoldResult (rewriter, loc, linearizedIndices));
90156}
91157
0 commit comments