Skip to content

Commit 20c2cf6

Browse files
committed
bug fix
1 parent 3690c61 commit 20c2cf6

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,9 +1741,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17411741
// 2. Values that are results of the `ForOp`:
17421742
// In this case, we record the index mapping between the `WarpOp` result
17431743
// index and matching `ForOp` result index.
1744+
// Additionally, we keep track of the distributed types for all `ForOp`
1745+
// vector results.
17441746
SmallVector<Value> nonForYieldedValues;
17451747
SmallVector<unsigned> nonForResultIndices;
17461748
llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
1749+
llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
17471750
for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
17481751
// Yielded value is not a result of the forOp.
17491752
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
@@ -1752,8 +1755,15 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17521755
continue;
17531756
}
17541757
OpResult forResult = cast<OpResult>(yieldOperand.get());
1755-
forResultMapping[yieldOperand.getOperandNumber()] =
1756-
forResult.getResultNumber();
1758+
unsigned int forResultNumber = forResult.getResultNumber();
1759+
forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
1760+
// If this `ForOp` result is vector type and it is yielded by the
1761+
// `WarpOp`, we keep track the distributed type for this result.
1762+
if (!isa<VectorType>(forResult.getType()))
1763+
continue;
1764+
VectorType distType = cast<VectorType>(
1765+
warpOp.getResult(yieldOperand.getOperandNumber()).getType());
1766+
forResultDistTypes[forResultNumber] = distType;
17571767
}
17581768

17591769
// Newly created `WarpOp` will yield values in following order:
@@ -1767,8 +1777,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17671777
// Compute the distributed type for this init arg.
17681778
Type distType = initArg.getType();
17691779
if (auto vecType = dyn_cast<VectorType>(distType)) {
1780+
// If the `ForOp` result corresponds to this init arg is already yielded
1781+
// we can get the distributed type from `forResultDistTypes` map.
1782+
// Otherwise, we compute it using distributionMapFn.
17701783
AffineMap map = distributionMapFn(initArg);
1771-
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1784+
distType = forResultDistTypes.count(i)
1785+
? forResultDistTypes[i]
1786+
: getDistributedType(vecType, map, warpOp.getWarpSize());
17721787
}
17731788
newWarpOpDistTypes.push_back(distType);
17741789
}

0 commit comments

Comments
 (0)