@@ -1741,9 +1741,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17411741 // 2. Values that are results of the `ForOp`:
17421742 // In this case, we record the index mapping between the `WarpOp` result
17431743 // index and matching `ForOp` result index.
1744+ // Additionally, we keep track of the distributed types for all `ForOp`
1745+ // vector results.
17441746 SmallVector<Value> nonForYieldedValues;
17451747 SmallVector<unsigned > nonForResultIndices;
17461748 llvm::SmallDenseMap<unsigned , unsigned > forResultMapping;
1749+ llvm::SmallDenseMap<unsigned , VectorType> forResultDistTypes;
17471750 for (OpOperand &yieldOperand : warpOpYield->getOpOperands ()) {
17481751 // Yielded value is not a result of the forOp.
17491752 if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ()) {
@@ -1752,8 +1755,15 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17521755 continue ;
17531756 }
17541757 OpResult forResult = cast<OpResult>(yieldOperand.get ());
1755- forResultMapping[yieldOperand.getOperandNumber ()] =
1756- forResult.getResultNumber ();
1758+ unsigned int forResultNumber = forResult.getResultNumber ();
1759+ forResultMapping[yieldOperand.getOperandNumber ()] = forResultNumber;
1760+ // If this `ForOp` result is a vector type and it is yielded by the
1761+ // `WarpOp`, we keep track of the distributed type for this result.
1762+ if (!isa<VectorType>(forResult.getType ()))
1763+ continue ;
1764+ VectorType distType = cast<VectorType>(
1765+ warpOp.getResult (yieldOperand.getOperandNumber ()).getType ());
1766+ forResultDistTypes[forResultNumber] = distType;
17571767 }
17581768
17591769 // Newly created `WarpOp` will yield values in following order:
@@ -1767,8 +1777,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17671777 // Compute the distributed type for this init arg.
17681778 Type distType = initArg.getType ();
17691779 if (auto vecType = dyn_cast<VectorType>(distType)) {
1780+ // If the `ForOp` result corresponding to this init arg is already yielded,
1781+ // we can get the distributed type from the `forResultDistTypes` map.
1782+ // Otherwise, we compute it using distributionMapFn.
17701783 AffineMap map = distributionMapFn (initArg);
1771- distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1784+ distType = forResultDistTypes.count (i)
1785+ ? forResultDistTypes[i]
1786+ : getDistributedType (vecType, map, warpOp.getWarpSize ());
17721787 }
17731788 newWarpOpDistTypes.push_back (distType);
17741789 }
0 commit comments