1818#include " mlir/Dialect/Affine/Analysis/NestedMatcher.h"
1919#include " mlir/Dialect/Affine/IR/AffineOps.h"
2020#include " mlir/Dialect/Affine/IR/AffineValueMap.h"
21+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
2122#include " llvm/Support/MathExtras.h"
2223
2324#include " llvm/ADT/DenseSet.h"
@@ -84,6 +85,67 @@ void mlir::affine::getTripCountMapAndOperands(
8485 tripCountValueMap.getOperands ().end ());
8586}
8687
88+ // / Replace thread_id with its maximum value, if `replaceWithZero` is true,
89+ // / thread_id will be replaced by its minimum value 0.
90+ static void replaceGPUOperands (AffineForOp forOp,
91+ SmallVectorImpl<Value> &operands,
92+ SmallVectorImpl<AffineExpr> &symReplacements,
93+ unsigned numDim, bool replaceWithZero = false ) {
94+ auto launchOp = forOp->getParentOfType <gpu::LaunchOp>();
95+ if (!launchOp)
96+ return ;
97+
98+ // `b` is only used to create `AffineExpr`.
99+ Builder b (forOp.getContext ());
100+ unsigned idx = 0 ;
101+
102+ for (unsigned i = numDim, e = operands.size (); i < e; ++i) {
103+ Value operand = operands[i];
104+ if (Value blockSize = launchOp.getBlockSizeOnAxis (operand)) {
105+ operands[i] = blockSize;
106+ if (!replaceWithZero)
107+ symReplacements.push_back (b.getAffineSymbolExpr (idx++) - 1 );
108+ else
109+ symReplacements.push_back (b.getAffineConstantExpr (0 ));
110+ continue ;
111+ }
112+
113+ Operation *defOp = operand.getDefiningOp ();
114+ if (!defOp) {
115+ ++idx;
116+ continue ;
117+ }
118+
119+ if (auto threadIdOp = mlir::dyn_cast<gpu::ThreadIdOp>(defOp)) {
120+ gpu::Dimension dimension = threadIdOp.getDimension ();
121+ operands[i] = launchOp.getBlockSizeOnAxis (dimension);
122+ if (!replaceWithZero)
123+ symReplacements.push_back (b.getAffineSymbolExpr (idx++) - 1 );
124+ else
125+ symReplacements.push_back (b.getAffineConstantExpr (0 ));
126+ continue ;
127+ }
128+ ++idx;
129+ }
130+ }
131+
132+ // / Take the min if all trip counts are constant.
133+ static std::optional<uint64_t >
134+ getConstantTripCountFromAffineMap (AffineMap map) {
135+ std::optional<uint64_t > tripCount;
136+ for (auto resultExpr : map.getResults ()) {
137+ auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr);
138+ if (!constExpr)
139+ return std::nullopt ;
140+ if (tripCount.has_value ())
141+ tripCount =
142+ std::min (*tripCount, static_cast <uint64_t >(constExpr.getValue ()));
143+ else
144+ tripCount = constExpr.getValue ();
145+ }
146+ return tripCount;
147+ }
148+
87149// / Returns the trip count of the loop if it's a constant, std::nullopt
88150// / otherwise. This method uses affine expression analysis (in turn using
89151// / getTripCount) and is able to determine constant trip count in non-trivial
@@ -95,20 +157,34 @@ std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
95157
96158 if (!map)
97159 return std::nullopt ;
160+ SmallVector<AffineExpr, 4 > symReplacements;
161+ replaceGPUOperands (forOp, operands, symReplacements, map.getNumDims ());
162+ map = map.replaceDimsAndSymbols ({}, symReplacements, map.getNumDims (),
163+ map.getNumSymbols ());
164+ affine::AffineValueMap valueMap (map, operands);
165+ (void )valueMap.canonicalize ();
166+ map = valueMap.getAffineMap ();
167+ return getConstantTripCountFromAffineMap (map);
168+ }
98169
99- // Take the min if all trip counts are constant.
100- std::optional<uint64_t > tripCount;
101- for (auto resultExpr : map.getResults ()) {
102- if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
103- if (tripCount.has_value ())
104- tripCount =
105- std::min (*tripCount, static_cast <uint64_t >(constExpr.getValue ()));
106- else
107- tripCount = constExpr.getValue ();
108- } else
109- return std::nullopt ;
110- }
111- return tripCount;
170+ // / In some scenarios, such as GPU, the number of trip of each thread in the
171+ // / loop is inconsistent. This function returns the maximum number of trip.
172+ std::optional<uint64_t >
173+ mlir::affine::getMaxConstantTripCount (AffineForOp forOp) {
174+ SmallVector<Value, 4 > operands;
175+ AffineMap map;
176+ getTripCountMapAndOperands (forOp, &map, &operands);
177+
178+ if (!map)
179+ return std::nullopt ;
180+ SmallVector<AffineExpr, 4 > symReplacements;
181+ replaceGPUOperands (forOp, operands, symReplacements, map.getNumDims (), true );
182+ map = map.replaceDimsAndSymbols ({}, symReplacements, map.getNumDims (),
183+ map.getNumSymbols ());
184+ affine::AffineValueMap valueMap (map, operands);
185+ (void )valueMap.canonicalize ();
186+ map = valueMap.getAffineMap ();
187+ return getConstantTripCountFromAffineMap (map);
112188}
113189
114190// / Returns the greatest known integral divisor of the trip count. Affine
@@ -121,7 +197,13 @@ uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
121197
122198 if (!map)
123199 return 1 ;
124-
200+ SmallVector<AffineExpr, 4 > symReplacements;
201+ replaceGPUOperands (forOp, operands, symReplacements, map.getNumDims ());
202+ map = map.replaceDimsAndSymbols ({}, symReplacements, map.getNumDims (),
203+ map.getNumSymbols ());
204+ affine::AffineValueMap valueMap (map, operands);
205+ (void )valueMap.canonicalize ();
206+ map = valueMap.getAffineMap ();
125207 // The largest divisor of the trip count is the GCD of the individual largest
126208 // divisors.
127209 assert (map.getNumResults () >= 1 && " expected one or more results" );
0 commit comments