@@ -47,12 +47,57 @@ using namespace mlir::transform::gpu;
 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
 
+/// Build predicates to filter execution by only the activeIds. Along each
+/// dimension, 3 cases appear:
+///   1. activeMappingSize > availableMappingSize: this is an unsupported case
+///      as this requires additional looping. An error message is produced to
+///      advise the user to tile more or to use more threads.
+///   2. activeMappingSize == availableMappingSize: no predication is needed.
+///   3. activeMappingSize < availableMappingSize: only a subset of threads
+///      should be active and we produce the boolean `id < activeMappingSize`
+///      for further use in building predicated execution.
+static FailureOr<SmallVector<Value>>
+buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds,
+                ArrayRef<int64_t> activeMappingSizes,
+                ArrayRef<int64_t> availableMappingSizes,
+                std::string &errorMsg) {
+  // clang-format off
+  LLVM_DEBUG(
+      llvm::interleaveComma(
+          activeMappingSizes, DBGS() << "----activeMappingSizes: ");
+      DBGS() << "\n";
+      llvm::interleaveComma(
+          availableMappingSizes, DBGS() << "----availableMappingSizes: ");
+      DBGS() << "\n";);
+  // clang-format on
+
+  SmallVector<Value> predicateOps;
+  for (auto [activeId, activeMappingSize, availableMappingSize] :
+       llvm::zip_equal(activeIds, activeMappingSizes, availableMappingSizes)) {
+    if (activeMappingSize > availableMappingSize) {
+      errorMsg = "Trying to map to fewer GPU threads than loop iterations but "
+                 "overprovisioning is not yet supported. Try additional tiling "
+                 "before mapping or map to more threads.";
+      return failure();
+    }
+    if (activeMappingSize == availableMappingSize)
+      continue;
+    Value idx = rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
+    Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                                activeId, idx);
+    predicateOps.push_back(pred);
+  }
+  return predicateOps;
+}
+
 /// Return a flattened thread id for the workgroup with given sizes.
 template <typename ThreadOrBlockIdOp>
 static Value buildLinearId(RewriterBase &rewriter, Location loc,
                            ArrayRef<OpFoldResult> originalBasisOfr) {
-  LLVM_DEBUG(DBGS() << "----buildLinearId with originalBasisOfr: "
-                    << llvm::interleaved(originalBasisOfr) << "\n");
+  LLVM_DEBUG(llvm::interleaveComma(
+                 originalBasisOfr,
+                 DBGS() << "----buildLinearId with originalBasisOfr: ");
+             llvm::dbgs() << "\n");
   assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
   IndexType indexType = rewriter.getIndexType();
   AffineExpr tx, ty, tz, bdx, bdy;
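
For context only: `buildPredicates` returns one `i1` value per dimension that needs guarding (or an empty vector when none do). A caller that wants a single guard condition could AND them together along the lines of the sketch below; the helper name `combinePredicates` is a hypothetical illustration, not an API added by this patch.

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: fold all per-dimension predicates into one i1 value.
// Returns a null Value when predicateOps is empty, i.e. no guarding is needed.
static Value combinePredicates(RewriterBase &rewriter, Location loc,
                               ArrayRef<Value> predicateOps) {
  if (predicateOps.empty())
    return Value();
  Value combined = predicateOps.front();
  for (Value pred : predicateOps.drop_front())
    combined = rewriter.create<arith::AndIOp>(loc, combined, pred);
  return combined;
}
```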
@@ -79,44 +124,43 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
   auto res = [multiplicity](RewriterBase &rewriter, Location loc,
                             ArrayRef<int64_t> forallMappingSizes,
                             ArrayRef<int64_t> originalBasis) {
+    // 1. Compute linearId.
     SmallVector<OpFoldResult> originalBasisOfr =
         getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
-    OpFoldResult linearId =
+    Value physicalLinearId =
         buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
+
+    // 2. Compute scaledLinearId.
+    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId});
+
+    // 3. Compute remapped indices.
+    SmallVector<Value> ids;
     // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
     // "row-major" order.
     SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
     SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
-    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
-    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, d0.floorDiv(multiplicity), {linearId});
     SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
-    SmallVector<Value> ids;
     // Reverse back to be in [0 .. n] order.
     for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
       ids.push_back(
           affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
     }
 
-    LLVM_DEBUG(DBGS() << "--delinearization basis: "
-                      << llvm::interleaved(reverseBasisSizes) << "\n";
-               DBGS() << "--delinearization strides: "
-                      << llvm::interleaved(strides) << "\n";
-               DBGS() << "--delinearization exprs: "
-                      << llvm::interleaved(delinearizingExprs) << "\n";
-               DBGS() << "--ids: " << llvm::interleaved(ids) << "\n");
-
-    // Return n-D ids for indexing and 1-D size + id for predicate generation.
-    return IdBuilderResult{
-        /*mappingIdOps=*/ids,
-        /*availableMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(originalBasis)},
-        // `forallMappingSizes` iterate in the scaled basis, they need to be
-        // scaled back into the original basis to provide tight
-        // activeMappingSizes quantities for predication.
-        /*activeMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(forallMappingSizes) * multiplicity},
-        /*activeIdOps=*/SmallVector<Value>{cast<Value>(linearId)}};
+    // 4. Handle predicates using physicalLinearId.
+    std::string errorMsg;
+    SmallVector<Value> predicateOps;
+    FailureOr<SmallVector<Value>> maybePredicateOps =
+        buildPredicates(rewriter, loc, physicalLinearId,
+                        computeProduct(forallMappingSizes) * multiplicity,
+                        computeProduct(originalBasis), errorMsg);
+    if (succeeded(maybePredicateOps))
+      predicateOps = *maybePredicateOps;
+
+    return IdBuilderResult{/*errorMsg=*/errorMsg,
+                           /*mappingIdOps=*/ids,
+                           /*predicateOps=*/predicateOps};
   };
 
   return res;
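
To make the predication arithmetic in `commonLinearIdBuilderFn` concrete, here is a small standalone sketch with illustrative numbers (a {2, 2} forall mapped at warp granularity, `multiplicity = 32`, onto a 256-thread block); the values are assumptions for the example, not taken from the patch.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t multiplicity = 32;                  // one warp per mapped id
  const int64_t activeSize = 2 * 2 * multiplicity;  // forall product * multiplicity = 128
  const int64_t availableSize = 256;                // computeProduct(originalBasis)

  for (int64_t tid : {0, 31, 32, 127, 128, 255}) {
    int64_t scaled = tid / multiplicity; // scaledLinearId = tid floordiv 32
    bool active = tid < activeSize;      // predicate is built on the *physical* id
    std::printf("tid=%3lld scaledLinearId=%lld active=%d\n", (long long)tid,
                (long long)scaled, (int)active);
  }
  // activeSize (128) < availableSize (256): case 3 of buildPredicates, so a
  // single `physicalLinearId < 128` predicate guards the mapped body, while
  // scaledLinearId is what gets delinearized into the {2, 2} forall ids.
  return 0;
}
```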
@@ -143,71 +187,65 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
     // In the 3-D mapping case, unscale the first dimension by the multiplicity.
     SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes);
     forallMappingSizeInOriginalBasis[0] *= multiplicity;
-    return IdBuilderResult{
-        /*mappingIdOps=*/scaledIds,
-        /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis},
-        // `forallMappingSizes` iterate in the scaled basis, they need to be
-        // scaled back into the original basis to provide tight
-        // activeMappingSizes quantities for predication.
-        /*activeMappingSizes=*/
-        SmallVector<int64_t>{forallMappingSizeInOriginalBasis},
-        /*activeIdOps=*/ids};
+
+    std::string errorMsg;
+    SmallVector<Value> predicateOps;
+    FailureOr<SmallVector<Value>> maybePredicateOps =
+        buildPredicates(rewriter, loc, ids, forallMappingSizeInOriginalBasis,
+                        originalBasis, errorMsg);
+    if (succeeded(maybePredicateOps))
+      predicateOps = *maybePredicateOps;
+
+    return IdBuilderResult{/*errorMsg=*/errorMsg,
+                           /*mappingIdOps=*/scaledIds,
+                           /*predicateOps=*/predicateOps};
   };
   return res;
 }
 
 /// Create a lane id builder that takes the `originalBasis` and decompose
 /// it in the basis of `forallMappingSizes`. The linear id builder returns an
 /// n-D vector of ids for indexing and 1-D size + id for predicate generation.
-static GpuIdBuilderFnType laneIdBuilderFn(int64_t periodicity) {
-  auto res = [periodicity](RewriterBase &rewriter, Location loc,
-                           ArrayRef<int64_t> forallMappingSizes,
-                           ArrayRef<int64_t> originalBasis) {
+static GpuIdBuilderFnType laneIdBuilderFn(int64_t warpSize) {
+  auto res = [warpSize](RewriterBase &rewriter, Location loc,
+                        ArrayRef<int64_t> forallMappingSizes,
+                        ArrayRef<int64_t> originalBasis) {
+    // 1. Compute linearId.
     SmallVector<OpFoldResult> originalBasisOfr =
         getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
-    OpFoldResult linearId =
+    Value physicalLinearId =
         buildLinearId<ThreadIdOp>(rewriter, loc, originalBasisOfr);
+
+    // 2. Compute laneId.
     AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
-    linearId = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, d0 % periodicity, {linearId});
+    OpFoldResult laneId = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0 % warpSize, {physicalLinearId});
 
+    // 3. Compute remapped indices.
+    SmallVector<Value> ids;
     // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
     // "row-major" order.
     SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
     SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
     SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
-    SmallVector<Value> ids;
     // Reverse back to be in [0 .. n] order.
     for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
       ids.push_back(
-          affine::makeComposedAffineApply(rewriter, loc, e, {linearId}));
+          affine::makeComposedAffineApply(rewriter, loc, e, {laneId}));
     }
 
-    // clang-format off
-    LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
-                                     DBGS() << "--delinearization basis: ");
-               llvm::dbgs() << "\n";
-               llvm::interleaveComma(strides,
-                                     DBGS() << "--delinearization strides: ");
-               llvm::dbgs() << "\n";
-               llvm::interleaveComma(delinearizingExprs,
-                                     DBGS() << "--delinearization exprs: ");
-               llvm::dbgs() << "\n";
-               llvm::interleaveComma(ids, DBGS() << "--ids: ");
-               llvm::dbgs() << "\n";);
-    // clang-format on
-
-    // Return n-D ids for indexing and 1-D size + id for predicate generation.
-    return IdBuilderResult{
-        /*mappingIdOps=*/ids,
-        /*availableMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(originalBasis)},
-        // `forallMappingSizes` iterate in the scaled basis, they need to be
-        // scaled back into the original basis to provide tight
-        // activeMappingSizes quantities for predication.
-        /*activeMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(forallMappingSizes)},
-        /*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
+    // 4. Handle predicates using laneId.
+    std::string errorMsg;
+    SmallVector<Value> predicateOps;
+    FailureOr<SmallVector<Value>> maybePredicateOps = buildPredicates(
+        rewriter, loc, cast<Value>(laneId), computeProduct(forallMappingSizes),
+        computeProduct(originalBasis), errorMsg);
+    if (succeeded(maybePredicateOps))
+      predicateOps = *maybePredicateOps;
+
+    return IdBuilderResult{/*errorMsg=*/errorMsg,
+                           /*mappingIdOps=*/ids,
+                           /*predicateOps=*/predicateOps};
   };
 
   return res;
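
Similarly, a small sketch of the `laneIdBuilderFn` arithmetic with illustrative numbers (warpSize = 32, a 16-iteration forall, a 64-thread block); these values are assumptions for the example only, not from the patch.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t warpSize = 32;
  const int64_t activeSize = 16;    // computeProduct(forallMappingSizes)
  const int64_t availableSize = 64; // computeProduct(originalBasis)

  for (int64_t tid = 0; tid < availableSize; tid += 10) {
    int64_t laneId = tid % warpSize;   // laneId = physicalLinearId % warpSize
    bool active = laneId < activeSize; // predicate is built on laneId
    std::printf("tid=%2lld laneId=%2lld active=%d\n", (long long)tid,
                (long long)laneId, (int)active);
  }
  // activeSize (16) < availableSize (64): case 3, so `laneId < 16` guards the
  // body and lanes 16..31 of every warp are predicated off.
  return 0;
}
```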