@@ -886,17 +886,31 @@ class RewriteScalarExtractOfTransferRead
886886 SmallVector<Value> newIndices (xferOp.getIndices ().begin (),
887887 xferOp.getIndices ().end ());
888888 for (auto [i, pos] : llvm::enumerate (extractOp.getMixedPosition ())) {
889- assert (isa<Attribute>(pos) && " Unexpected non-constant index" );
890- int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt ();
891889 int64_t idx = newIndices.size () - extractOp.getNumIndices () + i;
892- OpFoldResult ofr = affine::makeComposedFoldedAffineApply (
893- rewriter, extractOp.getLoc (),
894- rewriter.getAffineSymbolExpr (0 ) + offset, {newIndices[idx]});
895- if (auto value = dyn_cast<Value>(ofr)) {
890+
891+ // Compute affine expression `newIndices[idx] + pos` where `pos` can be
892+ // either a constant or a value.
893+ OpFoldResult composedIdx;
894+ if (auto attr = dyn_cast<Attribute>(pos)) {
895+ int64_t offset = cast<IntegerAttr>(attr).getInt ();
896+ composedIdx = affine::makeComposedFoldedAffineApply (
897+ rewriter, extractOp.getLoc (),
898+ rewriter.getAffineSymbolExpr (0 ) + offset, {newIndices[idx]});
899+ } else {
900+ Value dynamicOffset = cast<Value>(pos);
901+ AffineExpr sym0, sym1;
902+ bindSymbols (rewriter.getContext (), sym0, sym1);
903+ composedIdx = affine::makeComposedFoldedAffineApply (
904+ rewriter, extractOp.getLoc (), sym0 + sym1,
905+ {newIndices[idx], dynamicOffset});
906+ }
907+
908+ // Update the corresponding index with the folded result.
909+ if (auto value = dyn_cast<Value>(composedIdx)) {
896910 newIndices[idx] = value;
897911 } else {
898912 newIndices[idx] = rewriter.create <arith::ConstantIndexOp>(
899- extractOp.getLoc (), *getConstantIntValue (ofr ));
913+ extractOp.getLoc (), *getConstantIntValue (composedIdx ));
900914 }
901915 }
902916 if (isa<MemRefType>(xferOp.getBase ().getType ())) {
0 commit comments