Skip to content

Commit 90ef1ab

Browse files
committed
Address feedback
1 parent 784dda1 commit 90ef1ab

File tree

1 file changed

+9
-23
lines changed

1 file changed

+9
-23
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,8 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
373373

374374
/// Given a warpOp that contains ops with regions, the corresponding op's
375375
/// "inner" region and the distributionMapFn, get all values used by the op's
376-
/// region that are defined within the warpOp. Return the set of values, their
377-
/// types and their distributed types.
376+
/// region that are defined within the warpOp, but outside the inner region.
377+
/// Return the set of values, their types and their distributed types.
378378
std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
379379
SmallVector<Type>>
380380
getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
@@ -383,7 +383,8 @@ getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
383383
SmallVector<Type> escapingValueTypes;
384384
SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
385385
if (innerRegion.empty())
386-
return {escapingValues, escapingValueTypes, escapingValueDistTypes};
386+
return {std::move(escapingValues), std::move(escapingValueTypes),
387+
std::move(escapingValueDistTypes)};
387388
mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
388389
Operation *parent = operand->get().getParentRegion()->getParentOp();
389390
if (warpOp->isAncestor(parent)) {
@@ -398,7 +399,8 @@ getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
398399
escapingValueDistTypes.push_back(distType);
399400
}
400401
});
401-
return {escapingValues, escapingValueTypes, escapingValueDistTypes};
402+
return {std::move(escapingValues), std::move(escapingValueTypes),
403+
std::move(escapingValueDistTypes)};
402404
}
403405

404406
/// Distribute transfer_write ops based on the affine map returned by
@@ -1998,25 +2000,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
19982000
return failure();
19992001
// Collect Values that come from the `WarpOp` but are outside the `ForOp`.
20002002
// Those Values need to be returned by the new warp op.
2001-
llvm::SmallSetVector<Value, 32> escapingValues;
2002-
SmallVector<Type> escapingValueInputTypes;
2003-
SmallVector<Type> escapingValueDistTypes;
2004-
mlir::visitUsedValuesDefinedAbove(
2005-
forOp.getBodyRegion(), [&](OpOperand *operand) {
2006-
Operation *parent = operand->get().getParentRegion()->getParentOp();
2007-
if (warpOp->isAncestor(parent)) {
2008-
if (!escapingValues.insert(operand->get()))
2009-
return;
2010-
Type distType = operand->get().getType();
2011-
if (auto vecType = dyn_cast<VectorType>(distType)) {
2012-
AffineMap map = distributionMapFn(operand->get());
2013-
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
2014-
}
2015-
escapingValueInputTypes.push_back(operand->get().getType());
2016-
escapingValueDistTypes.push_back(distType);
2017-
}
2018-
});
2019-
2003+
auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2004+
getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2005+
distributionMapFn);
20202006
if (llvm::is_contained(escapingValueDistTypes, Type{}))
20212007
return failure();
20222008
// `WarpOp` can yield two types of values:

0 commit comments

Comments
 (0)