@@ -156,6 +156,63 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
156156 return res;
157157}
158158
159+ // / Create a lane id builder that takes the `originalBasis` and decompose
160+ // / it in the basis of `forallMappingSizes`. The linear id builder returns an
161+ // / n-D vector of ids for indexing and 1-D size + id for predicate generation.
162+ static GpuIdBuilderFnType laneIdBuilderFn (int64_t periodicity) {
163+ auto res = [periodicity](RewriterBase &rewriter, Location loc,
164+ ArrayRef<int64_t > forallMappingSizes,
165+ ArrayRef<int64_t > originalBasis) {
166+ SmallVector<OpFoldResult> originalBasisOfr =
167+ getAsIndexOpFoldResult (rewriter.getContext (), originalBasis);
168+ OpFoldResult linearId =
169+ buildLinearId<ThreadIdOp>(rewriter, loc, originalBasisOfr);
170+ AffineExpr d0 = getAffineDimExpr (0 , rewriter.getContext ());
171+ linearId = affine::makeComposedFoldedAffineApply (
172+ rewriter, loc, d0 % periodicity, {linearId});
173+
174+ // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
175+ // "row-major" order.
176+ SmallVector<int64_t > reverseBasisSizes (llvm::reverse (forallMappingSizes));
177+ SmallVector<int64_t > strides = computeStrides (reverseBasisSizes);
178+ SmallVector<AffineExpr> delinearizingExprs = delinearize (d0, strides);
179+ SmallVector<Value> ids;
180+ // Reverse back to be in [0 .. n] order.
181+ for (AffineExpr e : llvm::reverse (delinearizingExprs)) {
182+ ids.push_back (
183+ affine::makeComposedAffineApply (rewriter, loc, e, {linearId}));
184+ }
185+
186+ // clang-format off
187+ LLVM_DEBUG (llvm::interleaveComma (reverseBasisSizes,
188+ DBGS () << " --delinearization basis: " );
189+ llvm::dbgs () << " \n " ;
190+ llvm::interleaveComma (strides,
191+ DBGS () << " --delinearization strides: " );
192+ llvm::dbgs () << " \n " ;
193+ llvm::interleaveComma (delinearizingExprs,
194+ DBGS () << " --delinearization exprs: " );
195+ llvm::dbgs () << " \n " ;
196+ llvm::interleaveComma (ids, DBGS () << " --ids: " );
197+ llvm::dbgs () << " \n " ;);
198+ // clang-format on
199+
200+ // Return n-D ids for indexing and 1-D size + id for predicate generation.
201+ return IdBuilderResult{
202+ /* mappingIdOps=*/ ids,
203+ /* availableMappingSizes=*/
204+ SmallVector<int64_t >{computeProduct (originalBasis)},
205+ // `forallMappingSizes` iterate in the scaled basis, they need to be
206+ // scaled back into the original basis to provide tight
207+ // activeMappingSizes quantities for predication.
208+ /* activeMappingSizes=*/
209+ SmallVector<int64_t >{computeProduct (forallMappingSizes)},
210+ /* activeIdOps=*/ SmallVector<Value>{linearId.get <Value>()}};
211+ };
212+
213+ return res;
214+ }
215+
159216namespace mlir {
160217namespace transform {
161218namespace gpu {
@@ -221,6 +278,16 @@ GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping)
221278 : common3DIdBuilderFn<ThreadIdOp>(/* multiplicity=*/ 1 );
222279}
223280
281+ GpuLaneIdBuilder::GpuLaneIdBuilder (MLIRContext *ctx, int64_t warpSize,
282+ bool unused)
283+ : GpuIdBuilder(ctx, /* useLinearMapping=*/ true ,
284+ [](MLIRContext *ctx, MappingId id) {
285+ return GPULaneMappingAttr::get (ctx, id);
286+ }),
287+ warpSize (warpSize) {
288+ idBuilder = laneIdBuilderFn (/* periodicity=*/ warpSize);
289+ }
290+
224291DiagnosedSilenceableFailure checkGpuLimits (TransformOpInterface transformOp,
225292 std::optional<int64_t > gridDimX,
226293 std::optional<int64_t > gridDimY,
0 commit comments