@@ -606,7 +606,7 @@ struct MmaSyncBuilder {
606606 // / IndexCalculator callback.
607607 SmallVector<Value> buildMemRefLoads (OpBuilder &b, Location loc,
608608 OpFoldResult laneId, Value memref,
609- IndexCalculator indexFn);
609+ const IndexCalculator & indexFn);
610610
611611 // / Perform a distributed load of a vector operand of `vectorShape` for a
612612 // / particular MMA instruction whose `(row, col)` indices are specified via
@@ -625,7 +625,7 @@ struct MmaSyncBuilder {
625625 SmallVector<Operation *> buildMemRefStores (OpBuilder &b, Location loc,
626626 ValueRange toStore,
627627 OpFoldResult laneId, Value memref,
628- IndexCalculator indexFn);
628+ const IndexCalculator & indexFn);
629629
630630 // / Perform a distributed store of a vector operand of `vectorShape` for a
631631 // / particular MMA instruction whose `(row, col)` indices are specified via
@@ -660,10 +660,10 @@ static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
660660 }
661661}
662662
663- SmallVector<Value> MmaSyncBuilder::buildMemRefLoads (OpBuilder &b, Location loc,
664- OpFoldResult laneId ,
665- Value memref,
666- IndexCalculator indexFn) {
663+ SmallVector<Value>
664+ MmaSyncBuilder::buildMemRefLoads (OpBuilder &b, Location loc ,
665+ OpFoldResult laneId, Value memref,
666+ const IndexCalculator & indexFn) {
667667 auto aff = [&](AffineExpr e) {
668668 return affine::makeComposedFoldedAffineApply (b, loc, e, laneId);
669669 };
@@ -681,7 +681,7 @@ SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
681681Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand (
682682 OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
683683 IndexCalculator indexFn, ArrayRef<int64_t > vectorShape) {
684- auto loads = buildMemRefLoads (b, loc, laneId, memref, indexFn);
684+ auto loads = buildMemRefLoads (b, loc, laneId, memref, std::move ( indexFn) );
685685
686686 Type elementType = getElementTypeOrSelf (memref.getType ());
687687 auto vt = VectorType::get (vectorShape, elementType);
@@ -700,10 +700,9 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
700700 return res;
701701}
702702
703- SmallVector<Operation *>
704- MmaSyncBuilder::buildMemRefStores (OpBuilder &b, Location loc,
705- ValueRange toStore, OpFoldResult laneId,
706- Value memref, IndexCalculator indexFn) {
703+ SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores (
704+ OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
705+ Value memref, const IndexCalculator &indexFn) {
707706 auto aff = [&](AffineExpr e) {
708707 return affine::makeComposedFoldedAffineApply (b, loc, e, laneId);
709708 };
@@ -734,7 +733,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
734733 [&](Value v, int64_t linearIdx, ArrayRef<int64_t > indices) {
735734 toStore.push_back (v);
736735 });
737- return buildMemRefStores (b, loc, toStore, laneId, memref, indexFn);
736+ return buildMemRefStores (b, loc, toStore, laneId, memref, std::move ( indexFn) );
738737}
739738
740739static std::tuple<SmallVector<int64_t >, SmallVector<int64_t >,
0 commit comments