@@ -1554,36 +1554,22 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
15541554 llvm::SmallSetVector<Value, 32 > escapingValues;
15551555 SmallVector<Type> inputTypes;
15561556 SmallVector<Type> distTypes;
1557- auto collectEscapingValues = [&](Value value) {
1558- if (!escapingValues.insert (value))
1559- return ;
1560- Type distType = value.getType ();
1561- if (auto vecType = dyn_cast<VectorType>(distType)) {
1562- AffineMap map = distributionMapFn (value);
1563- distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1564- }
1565- inputTypes.push_back (value.getType ());
1566- distTypes.push_back (distType);
1567- };
1568-
15691557 mlir::visitUsedValuesDefinedAbove (
15701558 forOp.getBodyRegion (), [&](OpOperand *operand) {
15711559 Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
15721560 if (warpOp->isAncestor (parent)) {
1573- collectEscapingValues (operand->get ());
1561+ if (!escapingValues.insert (operand->get ()))
1562+ return ;
1563+ Type distType = operand->get ().getType ();
1564+ if (auto vecType = dyn_cast<VectorType>(distType)) {
1565+ AffineMap map = distributionMapFn (operand->get ());
1566+ distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1567+ }
1568+ inputTypes.push_back (operand->get ().getType ());
1569+ distTypes.push_back (distType);
15741570 }
15751571 });
15761572
1577- // Any forOp result that is not already yielded by the warpOp
1578- // region is also considered escaping and must be returned by the
1579- // original warpOp.
1580- for (OpResult forResult : forOp.getResults ()) {
1581- // Check if this forResult is already yielded by the yield op.
1582- if (llvm::is_contained (yield->getOperands (), forResult))
1583- continue ;
1584- collectEscapingValues (forResult);
1585- }
1586-
15871573 if (llvm::is_contained (distTypes, Type{}))
15881574 return failure ();
15891575
@@ -1623,12 +1609,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
16231609 forOp.getResultTypes ().end ());
16241610 llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
16251611 for (auto [i, retIdx] : llvm::enumerate (newRetIndices)) {
1626- auto newWarpResult = newWarpOp.getResult (retIdx);
1627- // Unused forOp results yielded by the warpOp region are already included
1628- // in the new ForOp.
1629- if (llvm::is_contained (newOperands, newWarpResult))
1630- continue ;
1631- warpInput.push_back (newWarpResult);
1612+ warpInput.push_back (newWarpOp.getResult (retIdx));
16321613 argIndexMapping[escapingValues[i]] = warpInputType.size ();
16331614 warpInputType.push_back (inputTypes[i]);
16341615 }
0 commit comments