 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/None.h"
@@ -157,45 +158,75 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
                       SmallVectorImpl<Value> &)>
         blockIdGenerator,
     SmallVector<int64_t> &gridDims, TransformOpInterface transformOp) {
+  // Step 0. Target-specific verifications. There is no good place to anchor
+  // those right now: the ForeachThreadOp is target-independent and the
+  // transform op does not apply to individual ForeachThreadOp.
+  MLIRContext *ctx = foreachThreadOp->getContext();
+  Location loc = foreachThreadOp->getLoc();
+  Attribute bX = GPUBlockMappingAttr::get(ctx, Blocks::DimX);
+  Attribute bY = GPUBlockMappingAttr::get(ctx, Blocks::DimY);
+  Attribute bZ = GPUBlockMappingAttr::get(ctx, Blocks::DimZ);
   if (foreachThreadOp.getNumResults() > 0)
     return transformOp.emitSilenceableError()
-           << "only bufferized scf.foreach_thread lowers to gpu.block_id";
+           << "only bufferized scf.foreach_thread lowers to "
+              "gpu.block_id";
   if (foreachThreadOp.getNumThreads().size() > 3)
     return transformOp.emitSilenceableError()
-           << "scf.foreach_thread with rank > 3 does not lower to gpu.block_id";
-
-  // Step 0. Outline the compute workload region and set up the workload
-  // operands.
-  SmallVector<int64_t> mapping;
+           << "scf.foreach_thread with rank > 3 does not lower to "
+              "gpu.block_id";
+  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
+        return !v.getDefiningOp<arith::ConstantIndexOp>();
+      })) {
+    return transformOp.emitSilenceableError()
+           << "unsupported dynamic griddim size";
+  }
   if (!foreachThreadOp.getMapping().has_value())
     return transformOp.emitSilenceableError() << "mapping must be present";
-  for (DeviceMappingAttrInterface map :
-       foreachThreadOp.getMapping()->getValue()) {
-    if (auto blockMap = map.dyn_cast<GPUBlockMappingAttr>()) {
-      mapping.push_back((int64_t)blockMap.getBlock());
-    } else {
-      return transformOp.emitSilenceableError()
-             << "mapping must be #gpu.block<x/y/z/>";
-    }
+  SmallVector<Attribute> blockMapping =
+      llvm::to_vector(foreachThreadOp.getMapping()->getValue());
+  if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) {
+        return !map.isa<GPUBlockMappingAttr>();
+      })) {
+    return transformOp.emitSilenceableError()
+           << "mapping must be #gpu.block<x/y/z/>";
   }
 
-  FailureOr<SmallVector<OpFoldResult>> potentialGridDim =
-      foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
-
-  if (failed(potentialGridDim) ||
-      llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
-      })) {
-    return transformOp.emitSilenceableError() << "unsupported dynamic gridDim";
+  // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
+  SmallVector<Value> numBlocks =
+      llvm::to_vector(foreachThreadOp.getNumThreads());
+  // Ensure we have 3 block sizes, one for each id.
+  Value one;
+  for (auto attr : {bX, bY, bZ}) {
+    if (std::find(blockMapping.begin(), blockMapping.end(), attr) ==
+        blockMapping.end()) {
+      blockMapping.push_back(attr);
+      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      numBlocks.push_back(one);
+    }
   }
 
-  for (OpFoldResult ofr : *potentialGridDim)
-    gridDims.push_back(getConstantIntValue(ofr).value());
+  // Step 2. sort the values by the corresponding GPUBlockMappingAttr.
+  auto comparator = [](Attribute a, Attribute b) -> bool {
+    return static_cast<int64_t>(a.cast<GPUBlockMappingAttr>().getBlock()) <
+           static_cast<int64_t>(b.cast<GPUBlockMappingAttr>().getBlock());
+  };
+  SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
+      blockMapping, numBlocks, comparator);
+  for (Value v : gridDimValues)
+    gridDims.push_back(v.getDefiningOp<arith::ConstantIndexOp>().value());
 
+  // Step 3. Generate the blockIds using the provided generator and map the
+  // induction variables to the newly created ops.
   SmallVector<Value> blockOps;
   blockIdGenerator(rewriter, foreachThreadOp, blockOps);
+  BlockAndValueMapping bvm;
+  for (auto [blockIdx, blockDim] :
+       llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
+    bvm.map(blockIdx, blockOps[static_cast<int64_t>(
+                          blockDim.cast<GPUBlockMappingAttr>().getBlock())]);
+  }
 
-  // Step 1. Move the body of foreachThreadOp.
+  // Step 4. Move the body of foreachThreadOp.
   // Erase the terminator first, it will not be used since we are on buffers.
   rewriter.eraseOp(foreachThreadOp.getTerminator());
   Block *targetBlock = foreachThreadOp->getBlock();
@@ -204,20 +235,16 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
   targetBlock->getOperations().splice(insertionPoint,
                                       sourceBlock.getOperations());
 
-  // Step 2. RAUW thread indices to thread ops.
-  SmallVector<Value> threadIndices =
-      *foreachThreadOp.getPermutedThreadIndices(mapping);
-  assert(blockOps.size() == 3 && "3 block id ops are required");
-  for (auto [blockIdx, blockOp] : llvm::zip(threadIndices, blockOps)) {
-    Value val = blockIdx;
-    Value blkOp = blockOp;
-    if (!val)
-      continue;
-    for (Operation *user : llvm::make_early_inc_range(val.getUsers()))
-      user->replaceUsesOfWith(val, blkOp);
+  // Step 5. RAUW thread indices to thread ops.
+  for (Value blockIdx : foreachThreadOp.getThreadIndices()) {
+    for (Operation *user : llvm::make_early_inc_range(blockIdx.getUsers())) {
+      rewriter.updateRootInPlace(user, [&]() {
+        user->replaceUsesOfWith(blockIdx, bvm.lookup(blockIdx));
+      });
+    }
   }
 
-  // Step 3. Erase old op.
+  // Step 6. Erase old op.
   rewriter.eraseOp(foreachThreadOp);
 
   return DiagnosedSilenceableFailure::success();
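
Note: the padding-and-sorting in Steps 1 and 2 above is the subtle part of this rewrite, so here is a minimal standalone sketch of the same idea in plain C++. It uses int dimension keys (0 = x, 1 = y, 2 = z) and int64_t sizes in place of the MLIR GPUBlockMappingAttr and Value types; the helper name completeAndSort is purely illustrative and not part of the patch.

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Any dimension absent from the mapping is appended with a size of 1, then the
// sizes are reordered by dimension key so the result is always ordered (x, y, z).
static std::vector<int64_t> completeAndSort(std::vector<int> dims,
                                            std::vector<int64_t> sizes) {
  for (int d = 0; d < 3; ++d) {
    if (std::find(dims.begin(), dims.end(), d) == dims.end()) {
      dims.push_back(d);
      sizes.push_back(1);
    }
  }
  std::vector<std::pair<int, int64_t>> zipped;
  for (size_t i = 0; i < dims.size(); ++i)
    zipped.emplace_back(dims[i], sizes[i]);
  std::sort(zipped.begin(), zipped.end());
  std::vector<int64_t> result;
  for (const auto &p : zipped)
    result.push_back(p.second);
  return result;
}

For example, a foreach_thread mapped to {y, x} with sizes {4, 128} yields gridDims of {128, 4, 1} for (x, y, z), which is what the padding loop plus getValuesSortedByKey produce in the patched code.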
@@ -252,11 +279,10 @@ static void generateGpuBlockIds(RewriterBase &rewriter,
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(foreachOp);
   IndexType indexType = rewriter.getIndexType();
-  SmallVector<Dimension> gpuDims{Dimension::x, Dimension::y, Dimension::z};
-  for (int64_t idx : llvm::seq<int64_t>(0, gpuDims.size())) {
-    blockOps.push_back(
-        rewriter.create<BlockIdOp>(loc, indexType, gpuDims[idx]));
-  }
+  blockOps = SmallVector<Value>{
+      rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
+      rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
+      rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
 }
 
 DiagnosedSilenceableFailure
@@ -333,61 +359,89 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
     const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
     llvm::Optional<TransformOpInterface> transformOp) {
+  // Step 0. Target-specific verifications. There is no good place to anchor
+  // those right now: the ForeachThreadOp is target-independent and the
+  // transform op does not apply to individual ForeachThreadOp.
   auto failureHelper =
       [&](const Twine &message) -> DiagnosedSilenceableFailure {
     if (transformOp.has_value()) {
       return transformOp->emitSilenceableError() << message;
     }
     return emitDefiniteFailure(foreachThreadOp, message);
   };
-
+  MLIRContext *ctx = foreachThreadOp->getContext();
+  Location loc = foreachThreadOp->getLoc();
+  Attribute tX = GPUThreadMappingAttr::get(ctx, Threads::DimX);
+  Attribute tY = GPUThreadMappingAttr::get(ctx, Threads::DimY);
+  Attribute tZ = GPUThreadMappingAttr::get(ctx, Threads::DimZ);
   if (foreachThreadOp.getNumResults() > 0)
     return failureHelper(
         "only bufferized scf.foreach_thread lowers to gpu.thread_id");
-
   if (foreachThreadOp.getNumThreads().size() > 3)
     return failureHelper(
         "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
-
-  SmallVector<int64_t> mapping;
+  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
+        return !v.getDefiningOp<arith::ConstantIndexOp>();
+      })) {
+    return failureHelper("unsupported dynamic blockdim size");
+  }
   if (!foreachThreadOp.getMapping().has_value())
     return failureHelper("mapping must be present");
-  for (DeviceMappingAttrInterface map :
-       foreachThreadOp.getMapping()->getValue()) {
-    if (auto threadMap = map.dyn_cast<GPUThreadMappingAttr>()) {
-      mapping.push_back((int64_t)threadMap.getThread());
-    } else {
-      return failureHelper("mapping must be #gpu.thread<x/y/z/>");
-    }
-  }
-  FailureOr<SmallVector<OpFoldResult>> potentialBlockDim =
-      foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
-  if (failed(potentialBlockDim) ||
-      llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
+  SmallVector<Attribute> threadMapping =
+      llvm::to_vector(foreachThreadOp.getMapping()->getValue());
+  if (llvm::any_of(threadMapping, [](DeviceMappingAttrInterface map) {
+        return !map.isa<GPUThreadMappingAttr>();
       })) {
-    return failureHelper("unsupported dynamic blockdim size");
+    return transformOp->emitSilenceableError()
+           << "mapping must be #gpu.thread<x/y/z/>";
   }
 
-  SmallVector<int64_t> blockDim =
-      llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) {
-        return getConstantIntValue(ofr).value();
+  // Step 1. Complete the threadMapping to a full mapping (with 1s) if
+  // necessary.
+  SmallVector<Value> numThreads =
+      llvm::to_vector(foreachThreadOp.getNumThreads());
+  // Ensure we have 3 block sizes, one for each id.
+  Value one;
+  for (auto attr : {tX, tY, tZ}) {
+    if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
+        threadMapping.end()) {
+      threadMapping.push_back(attr);
+      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      numThreads.push_back(one);
+    }
+  }
+
+  // Step 2. sort the values by the corresponding GPUThreadMappingAttr.
+  auto comparator = [](Attribute a, Attribute b) -> bool {
+    return static_cast<int64_t>(a.cast<GPUThreadMappingAttr>().getThread()) <
+           static_cast<int64_t>(b.cast<GPUThreadMappingAttr>().getThread());
+  };
+  SmallVector<Value> blockDimValues =
+      scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads,
+                                                 comparator);
+  SmallVector<int64_t> blockDims =
+      llvm::to_vector(llvm::map_range(blockDimValues, [](Value v) {
+        return v.getDefiningOp<arith::ConstantIndexOp>().value();
       }));
 
-  // Step 1. Create the gpu.thread ops
-  Location loc = foreachThreadOp.getLoc();
+  // Step 3. Create the gpu.thread ops and map the induction variables to the
+  // newly created ops.
   IndexType indexType = rewriter.getIndexType();
-
-  SmallVector<Dimension> gpuDims{Dimension::x, Dimension::y, Dimension::z};
-  SmallVector<Value> threadOps;
-  for (int64_t idx : llvm::seq<int64_t>(0, blockDim.size())) {
-    threadOps.push_back(
-        rewriter.create<ThreadIdOp>(loc, indexType, gpuDims[idx]));
+  SmallVector<Value> threadOps{
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
+  BlockAndValueMapping bvm;
+  for (auto [blockIdx, blockDim] :
+       llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
+    bvm.map(blockIdx, threadOps[static_cast<int64_t>(
+                          blockDim.cast<GPUThreadMappingAttr>().getThread())]);
   }
-  // Step 2. Maybe create conditionals to predicate the region.
+
+  // Step 4. Maybe create conditionals to predicate the region.
   Value predicate;
   for (auto [threadId, blockDim, globalBlockDim] :
-       llvm::zip(threadOps, blockDim, globalBlockDims)) {
+       llvm::zip(threadOps, blockDims, globalBlockDims)) {
     if (blockDim > globalBlockDim) {
       return failureHelper(
           "The requested GPU threads are fewer than the number of loop trip "
@@ -404,45 +458,41 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
                             : tmpPredicate;
   }
 
-  // Step 3. Move the body of foreachThreadOp.
+  // Step 5. Move the body of foreachThreadOp.
   // Erase the terminator first, it will not be used.
   rewriter.eraseOp(foreachThreadOp.getTerminator());
   Block *targetBlock;
   Block::iterator insertionPoint;
   if (predicate) {
-    // Step 3.a. If predicated, move at the beginning.
+    // Step 5.a. If predicated, move at the beginning.
     auto ifOp =
         rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
     targetBlock = ifOp.thenBlock();
     insertionPoint = ifOp.thenBlock()->begin();
   } else {
-    // Step 3.a. Otherwise, move inline just before foreachThreadOp.
+    // Step 5.b. Otherwise, move inline just before foreachThreadOp.
     targetBlock = foreachThreadOp->getBlock();
     insertionPoint = Block::iterator(foreachThreadOp);
   }
   Block &sourceBlock = foreachThreadOp.getRegion().front();
   targetBlock->getOperations().splice(insertionPoint,
                                       sourceBlock.getOperations());
 
-  // Step 4. RAUW thread indices to thread ops.
-  SmallVector<Value> threadIndices =
-      *foreachThreadOp.getPermutedThreadIndices(mapping);
-  for (auto [threadIdx, threadOp] : llvm::zip(threadIndices, threadOps)) {
-    Value val = threadIdx;
-    Value op = threadOp;
-    if (!val)
-      continue;
-    for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
-      user->replaceUsesOfWith(val, op);
+  // Step 6. RAUW thread indices to thread ops.
+  for (Value threadIdx : foreachThreadOp.getThreadIndices()) {
+    for (Operation *user : llvm::make_early_inc_range(threadIdx.getUsers())) {
+      rewriter.updateRootInPlace(user, [&]() {
+        user->replaceUsesOfWith(threadIdx, bvm.lookup(threadIdx));
+      });
     }
   }
 
-  // Step 5. syncthreads.
+  // Step 7. syncthreads.
   // TODO: Need warpsync
   if (syncAfterDistribute)
     rewriter.create<BarrierOp>(loc);
 
-  // Step 6. Erase old op.
+  // Step 8. Erase old op.
   rewriter.eraseOp(foreachThreadOp);
 
   return DiagnosedSilenceableFailure::success();
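
Note: Step 4 of rewriteOneForeachThreadToGpuThreads predicates the moved body when fewer threads are requested in some dimension than the kernel launches (the full predicate construction sits between the two hunks above). The following scalar model is only an illustration of that masking, with a hypothetical isActive helper and plain int64_t arrays standing in for the SSA values; it is not the patch's code.

#include <cstdint>

// A thread executes the moved body only if, in every dimension where fewer
// threads were requested (blockDims) than launched (globalBlockDims), its id
// is below the requested size. Dimensions with matching sizes need no check.
static bool isActive(const int64_t threadId[3], const int64_t blockDims[3],
                     const int64_t globalBlockDims[3]) {
  bool active = true;
  for (int d = 0; d < 3; ++d) {
    if (blockDims[d] < globalBlockDims[d])
      active = active && (threadId[d] < blockDims[d]);
  }
  return active;
}

For instance, with requested blockDims {32, 4, 1} inside a launch of {64, 4, 1}, only threads whose x id is below 32 enter the guarded region; the y and z dimensions require no comparison.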