@@ -78,29 +78,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
7878
7979 // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
8080 int64_t scaler = dstBits / srcBits;
81-
82- // If all strides and sizes are constant, we can compute the result
83- // directly without creating the AffineMaxOp.
84- int64_t constResult = 0 ;
85- int64_t constStride = 0 ;
86- int64_t constSize = 0 ;
87- bool isAllConstant = true ;
88- for (unsigned i = 0 ; i < sourceRank; ++i) {
89- if (auto constantStride = getConstantIntValue (strides[i])) {
90- constStride = *constantStride;
91- } else {
92- isAllConstant = false ;
93- break ;
94- }
95- if (auto constantSize = getConstantIntValue (sizes[i])) {
96- constSize = *constantSize;
97- } else {
98- isAllConstant = false ;
99- break ;
100- }
101- constResult = std::max (constResult, constStride * constSize / scaler);
102- }
103-
10481 size_t symbolIndex = 0 ;
10582 SmallVector<Value> values;
10683 SmallVector<AffineExpr> productExpressions;
@@ -129,10 +106,11 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
129106 builder.getContext ());
130107
131108 OpFoldResult linearizedSize;
132- if (isAllConstant) {
133- linearizedSize = builder.getIndexAttr (constResult);
109+ Value totalSize =
110+ builder.createOrFold <affine::AffineMaxOp>(loc, maxMap, values);
111+ if (auto constantSize = getConstantIntValue (totalSize)) {
112+ linearizedSize = builder.getIndexAttr (*constantSize);
134113 } else {
135- Value totalSize = builder.create <affine::AffineMaxOp>(loc, maxMap, values);
136114 linearizedSize = totalSize;
137115 }
138116
0 commit comments