@@ -545,12 +545,6 @@ void updateCurrentArgsStatus(ValueRange loopState, const size_t loopStateIdx,
545545 DenseMap<Value, Value> &nextOriginalOperandMap,
546546 DenseMap<Value, Value> &nextOperandOriginalMap) {
547547 Value currentArgs = loopState[loopStateIdx];
548- if (currentArgs.getType () != originalValue.getType ()) {
549- llvm::outs () << loopStateIdx << " ,"
550- << " \n " ;
551- currentArgs.dump ();
552- llvm::llvm_unreachable_internal (" Type not equal." );
553- }
554548 nextAnchorArgs.emplace_back (currentArgs);
555549 nextAnchorArgsIdxMap[currentArgs] = nextAnchorArgs.size () - 1 ;
556550 nextOriginalOperandMap[originalValue] = currentArgs;
@@ -740,6 +734,36 @@ void updateLoopArgsData(Value val, Value originalVal,
740734 originalOperandLoopArgsMap[originalVal] = val;
741735}
742736
737+ void LoopGeneratorImpl::rectifyParallelIndice (
738+ GenerateLoopHelper &loopHelperParam, OpBuilder &b, Location loc) {
739+ MultiReductionCanonicalizer rdCanonicalizer =
740+ getMultiRdCanonicalizers ()[loopHelperParam.groupIdx ];
741+ auto &multireductionOp = rdCanonicalizer.getCandidateOps ()[0 ];
742+ SmallVector<int64_t , 4 > &reductionAxis = rdCanonicalizer.getReductionAxis ();
743+
744+ // rectify indice of read from source operand
745+ auto sourceReadOp =
746+ multireductionOp.getSource ().getDefiningOp <vector::TransferReadOp>();
747+ if (!sourceReadOp)
748+ return ;
749+
750+ AffineExpr outterParallel, innerParallel;
751+ bindDims (multireductionOp->getContext (), outterParallel, innerParallel);
752+
753+ Value op =
754+ loopHelperParam.inductionVars [loopHelperParam.inductionVars .size () -
755+ reductionAxis.size () - 2 ];
756+ Value ip =
757+ loopHelperParam.inductionVars [loopHelperParam.inductionVars .size () -
758+ reductionAxis.size () - 1 ];
759+ Value newIndice = b.createOrFold <affine::AffineApplyOp>(
760+ loc, (outterParallel + innerParallel), ValueRange{op, ip});
761+ int parallelSize = rdCanonicalizer.getParallelAxis ().size ();
762+ int readIndiceOffset =
763+ 1 + rdCanonicalizer.getParallelAxis ()[parallelSize - 1 ];
764+ sourceReadOp->setOperand (readIndiceOffset, newIndice);
765+ }
766+
743767scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop (
744768 OpBuilder &opBuilder, const size_t reductionIdx,
745769 GenerateLoopHelper &loopHelperParam) {
@@ -755,18 +779,22 @@ scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop(
755779
756780 const auto loc = multireductionOp->getLoc ();
757781 SmallVector<int64_t , 4 > &reductionAxis = rdCanonicalizer.getReductionAxis ();
758- bool lastDimReduction = rdCanonicalizer.hasLastDimReduction ();
759782 VectorType vectorType = rdCanonicalizer.getSourceType ();
760- const int loopStep =
761- getVectorBasedFusion ().getGroupMaxSteps ()[loopHelperParam.groupIdx ];
783+ auto tpHelper = fusionStrategy.getTypeHelper ();
784+
785+ int loopStep = tpHelper.generateValidSteps (
786+ fusionStrategy.getTypeHelper ().getDataTypeMAXSIMDLength (vectorType),
787+ vectorType, vectorType.getShape ()[reductionAxis[reductionIdx]]);
788+ bool isLastDimReduction = rdCanonicalizer.getHasLastDimReduction ();
789+ loopStep = (reductionIdx == reductionAxis.size () - 1 && isLastDimReduction)
790+ ? loopStep
791+ : 1 ;
792+
762793 func::FuncOp func = fusionStrategy.getFunction ();
763794 IRRewriter rewriterOfFunc (func);
764795
765796 Value zero = makeIndexArithConstantOp (opBuilder, loc, 0 );
766- Value forSteps = makeIndexArithConstantOp (
767- opBuilder, loc,
768- (reductionIdx == reductionAxis.size () - 1 && lastDimReduction) ? loopStep
769- : 1 );
797+ Value forSteps = makeIndexArithConstantOp (opBuilder, loc, loopStep);
770798 Value numIter = makeIndexArithConstantOp (
771799 opBuilder, loc, vectorType.getShape ()[reductionAxis[reductionIdx]]);
772800 scf::ForOp forOp = opBuilder.create <scf::ForOp>(
@@ -868,9 +896,12 @@ scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop(
868896 }
869897
870898 rewriteOperationAsVectorize (b, loopHelperParam.groupIdx ,
871- &movingOperation);
899+ &movingOperation,
900+ isLastDimReduction ? loopStep : 0 );
872901 loopHelperParam.loopIterArgs = loopState;
873902 moveOperationsToCurrentForBody (b, movingOperation, loopHelperParam);
903+ if (isLastDimReduction)
904+ rectifyParallelIndice (loopHelperParam, b, loc);
874905 loopHelperParam.movedOps = &movingOperation;
875906 loopHelperParam.candidateOps = &opQueue;
876907
@@ -1058,11 +1089,16 @@ scf::ForOp LoopGeneratorImpl::parallelAxisGenerateForLoop(
10581089 // get accumualte value
10591090 Attribute initValueAttr;
10601091 getReductionInitAttr (multiReductionOp, initValueAttr);
1061-
1092+ SmallVector<int64_t , 4 > &reductionAxis =
1093+ rdCanonicalizer.getReductionAxis ();
1094+ TypeHelper tpHelper = fusionStrategy.getTypeHelper ();
1095+ int loopStep = tpHelper.generateValidSteps (
1096+ tpHelper.getDataTypeMAXSIMDLength (vectorType), vectorType,
1097+ vectorType.getShape ()[reductionAxis[reductionAxis.size () - 1 ]]);
10621098 auto accVal = b.create <arith::ConstantOp>(
10631099 loc, DenseElementsAttr::get (
10641100 fusionStrategy.getTypeHelper ().getVectorzedType (
1065- multiReductionOp, dimSize ),
1101+ multiReductionOp, loopStep ),
10661102 {initValueAttr}));
10671103
10681104 // put accumulte val at first for loop args
@@ -1247,14 +1283,14 @@ void LoopGeneratorImpl::rearrageMultiReductionIR(
12471283 DenseMap<size_t , size_t > varLoopIdxMap;
12481284 VectorType groupVector =
12491285 getVectorBasedFusion ().getGroupBiggestRankVectorType ()[grpIdx];
1250- for (size_t i = 0 ; i < parallelAxis.size (); i++) {
1286+ for (size_t i = 0 ; i < parallelAxis.size (); i++)
12511287 varLoopIdxMap[parallelAxis[i]] = i;
1252- }
1288+
12531289 size_t offset = rdCanonicalizer.hasLastDimReduction () ? 1 : 0 ;
12541290 for (size_t i = parallelAxis.size () + offset;
1255- i < groupVector.getRank () + offset; i++) {
1291+ i < groupVector.getRank () + offset; i++)
12561292 varLoopIdxMap[reductionAxis[i - parallelAxis.size () - offset]] = i;
1257- }
1293+
12581294 while (!tmpSourceQ.empty ()) {
12591295 auto *curOp = tmpSourceQ.front ();
12601296 tmpSourceQ.pop ();
@@ -2313,7 +2349,8 @@ void ForLoopGenerator::createNewConstantOp(
23132349
23142350// / Rewrite the operations in the group to vectorized form.
23152351void ForLoopGenerator::rewriteOperationAsVectorize (
2316- OpBuilder &rewriter, size_t groupId, const std::queue<Operation *> *queue) {
2352+ OpBuilder &rewriter, size_t groupId, const std::queue<Operation *> *queue,
2353+ const size_t vectorizeStep) {
23172354 const std::queue<Operation *> groupOps =
23182355 !queue ? getVectorBasedFusion ().getOpGroups ()[groupId] : *queue;
23192356
@@ -2322,7 +2359,9 @@ void ForLoopGenerator::rewriteOperationAsVectorize(
23222359 DenseMap<Operation *, AffineMap> &opPermuationMap =
23232360 getVectorBasedFusion ().getOpPermuationMap ();
23242361 std::queue<Operation *> transformQueue (groupOps);
2325- size_t groupSteps = getVectorBasedFusion ().getGroupMaxSteps ()[groupId];
2362+ size_t groupSteps = vectorizeStep == 0
2363+ ? getVectorBasedFusion ().getGroupMaxSteps ()[groupId]
2364+ : vectorizeStep;
23262365
23272366 while (!transformQueue.empty ()) {
23282367 Operation *op = transformQueue.front ();
0 commit comments