@@ -735,33 +735,37 @@ void updateLoopArgsData(Value val, Value originalVal,
735735}
736736
// Rewrites one index operand of the vector.transfer_read op(s) that feed the
// first candidate multi-reduction op of group `loopHelperParam.groupIdx`, so
// that the index becomes the sum of two loop induction variables.
//
// NOTE: this text is a diff view. Lines whose number is followed by `-` are the
// removed implementation, lines followed by `+` are the added one, and lines
// with two fused numbers are unchanged context.
//
// loopHelperParam: supplies `groupIdx` and the `inductionVars` list read below.
// loc:             location attached to the created affine.apply op.
// (The removed signature additionally took an `OpBuilder &b`; the added version
// constructs its own IRRewriter per read op instead.)
737737void LoopGeneratorImpl::rectifyParallelIndice (
738- GenerateLoopHelper &loopHelperParam, OpBuilder &b, Location loc) {
738+ GenerateLoopHelper &loopHelperParam, Location loc) {
739739 MultiReductionCanonicalizer rdCanonicalizer =
740740 getMultiRdCanonicalizers ()[loopHelperParam.groupIdx ];
741741 auto &multireductionOp = rdCanonicalizer.getCandidateOps ()[0 ];
742742 SmallVector<int64_t , 4 > &reductionAxis = rdCanonicalizer.getReductionAxis ();
743743
744744 // rectify indice of read from source operand
// --- Removed implementation ---
// Handled only the case where the multi-reduction source is defined directly
// by a vector::TransferReadOp, and returned early otherwise. The new index was
// created through the caller-provided builder `b`.
745- auto sourceReadOp =
746- multireductionOp.getSource ().getDefiningOp <vector::TransferReadOp>();
747- if (!sourceReadOp)
748- return ;
749-
750- AffineExpr outterParallel, innerParallel;
751- bindDims (multireductionOp->getContext (), outterParallel, innerParallel);
752-
753- Value op =
754- loopHelperParam.inductionVars [loopHelperParam.inductionVars .size () -
755- reductionAxis.size () - 2 ];
756- Value ip =
757- loopHelperParam.inductionVars [loopHelperParam.inductionVars .size () -
758- reductionAxis.size () - 1 ];
759- Value newIndice = b.createOrFold <affine::AffineApplyOp>(
760- loc, (outterParallel + innerParallel), ValueRange{op, ip});
761- int parallelSize = rdCanonicalizer.getParallelAxis ().size ();
762- int readIndiceOffset =
763- 1 + rdCanonicalizer.getParallelAxis ()[parallelSize - 1 ];
764- sourceReadOp->setOperand (readIndiceOffset, newIndice);
// --- Added implementation ---
// Fills `candidateOps` via getSameBlockTargetOp<vector::TransferReadOp>, seeded
// with the op defining the multi-reduction source, then applies the same index
// rewrite to every queued op.
// Presumably the helper collects transfer_read ops reachable from the seed
// within the same block; its definition is not visible here -- TODO confirm.
// NOTE(review): the seed from getDefiningOp() is no longer null-checked in this
// function (the removed code returned early on a failed lookup); confirm that
// getSameBlockTargetOp tolerates a null Operation*.
745+ std::queue<Operation *> candidateOps;
746+ getSameBlockTargetOp<vector::TransferReadOp>(
747+ multireductionOp.getSource ().getDefiningOp (), candidateOps);
748+ while (not candidateOps.empty ()) {
749+ auto sourceReadOp = candidateOps.front ();
750+ candidateOps.pop ();
// The insertion point is set to the read op itself before the new index value
// is created.
751+ IRRewriter rewriter (sourceReadOp);
752+ rewriter.setInsertionPoint (sourceReadOp);
753+ AffineExpr outterParallel, innerParallel;
754+ bindDims (multireductionOp->getContext (), outterParallel, innerParallel);
755+
// `op` and `ip` are the two induction variables located immediately before the
// last reductionAxis.size() entries of inductionVars.
// NOTE(review): the index arithmetic assumes inductionVars holds at least
// reductionAxis.size() + 2 entries; no check is visible here.
756+ Value op =
757+ loopHelperParam.inductionVars [loopHelperParam.inductionVars .size () -
758+ reductionAxis.size () - 2 ];
759+ Value ip =
760+ loopHelperParam.inductionVars [loopHelperParam.inductionVars .size () -
761+ reductionAxis.size () - 1 ];
// New index = op + ip, expressed as an affine.apply over the two bound dims
// (createOrFold may return a folded value instead of a new op).
762+ Value newIndice = rewriter.createOrFold <affine::AffineApplyOp>(
763+ loc, (outterParallel + innerParallel), ValueRange{op, ip});
// The operand replaced is at position 1 + (last entry of getParallelAxis()).
// Presumably the +1 skips the read's source operand so the offset addresses
// the index operand of that axis -- TODO confirm against the op definition.
764+ int parallelSize = rdCanonicalizer.getParallelAxis ().size ();
765+ int readIndiceOffset =
766+ 1 + rdCanonicalizer.getParallelAxis ()[parallelSize - 1 ];
767+ sourceReadOp->setOperand (readIndiceOffset, newIndice);
768+ }
765769}
766770
767771scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop (
@@ -901,7 +905,7 @@ scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop(
901905 loopHelperParam.loopIterArgs = loopState;
902906 moveOperationsToCurrentForBody (b, movingOperation, loopHelperParam);
903907 if (isLastDimReduction)
904- rectifyParallelIndice (loopHelperParam, b, loc);
908+ rectifyParallelIndice (loopHelperParam, loc);
905909 loopHelperParam.movedOps = &movingOperation;
906910 loopHelperParam.candidateOps = &opQueue;
907911
@@ -2768,7 +2772,7 @@ void GroupOperationFusionImpl::broadcastFromElements(Operation *op,
27682772 DenseElementsAttr::get (dataType, constantOp.getValue ()),
27692773 newOperandType);
27702774 if (failed (res))
2771- llvm::llvm_unreachable_internal (" Wrong to create constant op." );
2775+ llvm_unreachable (" Wrong to create constant op." );
27722776 removeOpInCurrentGroups (grpIdx, op, res.value ().getDefiningOp ());
27732777
27742778 } else {
0 commit comments