1717#include " mlir/IR/AffineExpr.h"
1818#include " mlir/IR/Attributes.h"
1919#include " mlir/IR/BuiltinTypes.h"
20+ #include " mlir/IR/Value.h"
2021#include " mlir/Interfaces/SideEffectInterfaces.h"
2122#include " mlir/Transforms/RegionUtils.h"
2223#include " llvm/ADT/SetVector.h"
@@ -1039,7 +1040,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
10391040 bool isResultDistributed = distributedResultType.getNumElements () <
10401041 oldCastOp.getResultVectorType ().getNumElements ();
10411042
1042- // If the result is not distributed, source distribted type is the same
1043+ // If the result is not distributed, source distributed type is the same
10431044 // as the source type. If the result is distributed, we need to compute the
10441045 // distributed source type according to the following rules:
10451046 // 1. If the source type is yielded from the warp op, we can use the
@@ -1051,7 +1052,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
10511052 // Check if the source is yielded from the warp op.
10521053 gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
10531054 warpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1054- auto *it =
1055+ OpOperand *it =
10551056 llvm::find_if (yieldOp->getOpOperands (), [&](OpOperand &operand) {
10561057 return operand.get () == oldCastOp.getSource ();
10571058 });
@@ -2155,7 +2156,9 @@ struct WarpOpMultiReduction : public WarpDistributionPattern {
21552156 // case each lane owns its portion of the result (i.e. result is also
21562157 // distributed).
21572158 // 3. If reduction dim == 1, it's a row reduction that requires cross-lane
2158- // shuffles. In this case result is not distributed and broadcasted instead.
2159+ // shuffles. In this case, the reduction result is not distributed across
2160+ // lanes. Instead, each lane owns a complete copy of the result
2161+ // (broadcasted).
21592162 // TODO: These assumptions are fairly restrictive. For example, source
21602163 // vector can have row distributed layout. Improve support for such cases.
21612164 if (sourceType.getShape ()[1 ] % warpOp.getWarpSize () != 0 )
0 commit comments