
Commit f0a411d

[mlir][Transform] Significantly clean up scf.foreach_thread and GPU transform permutation handling

Previously, the need for a dense permutation leaked into the thread_dim_mapping specification. This revision allows the thread_dim_mapping to be specified sparsely; the proper completion and sorting are applied automatically. In the process, the semantics of scf.foreach_thread is tightened to require a matching number of thread dimensions and mappings. The relevant negative test is added.

Differential Revision: https://reviews.llvm.org/D137906
1 parent 87f652d commit f0a411d
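The completion-and-sorting scheme described in the commit message is easiest to see on a tiny standalone example. The plain C++ sketch below is not the MLIR API; the Dim enum and the literal sizes are illustrative. It mirrors the two steps the patch adds: pad a sparse mapping with the missing dimensions and a size of 1, then order the sizes by their mapping key.

// Minimal standalone sketch of "complete a sparse mapping, then sort by key".
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

enum class Dim : int64_t { X = 0, Y = 1, Z = 2 };

int main() {
  // Sparse user specification: a 2-D loop mapped as (Y, X).
  std::vector<Dim> mapping = {Dim::Y, Dim::X};
  std::vector<int64_t> numThreads = {7, 9};

  // Step 1: complete the mapping with the missing dims and a trip count of 1.
  for (Dim d : {Dim::X, Dim::Y, Dim::Z}) {
    if (std::find(mapping.begin(), mapping.end(), d) == mapping.end()) {
      mapping.push_back(d);
      numThreads.push_back(1);
    }
  }

  // Step 2: sort the sizes by their mapping key (X < Y < Z).
  std::vector<size_t> idx(mapping.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(), [&](size_t a, size_t b) {
    return static_cast<int64_t>(mapping[a]) < static_cast<int64_t>(mapping[b]);
  });
  for (size_t i : idx)
    std::cout << numThreads[i] << " "; // prints: 9 7 1
  std::cout << "\n";
}

With the sparse input {Y, X} and sizes {7, 9}, the completed and sorted result is 9 threads along x, 7 along y, and 1 along z, which is exactly the dense specification the old interface forced the user to write by hand.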

File tree: 6 files changed, +195 / -166 lines


mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 8 additions & 9 deletions
@@ -536,15 +536,14 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
       return getBody()->getArguments().drop_front(getRank());
     }
 
-    /// Return the thread indices in the order specified by the
-    /// given mapping argument. Return failure is
-    /// mapping is not a valid permutation.
-    FailureOr<SmallVector<Value>> getPermutedThreadIndices(ArrayRef<int64_t> mapping);
-
-    /// Return the number of threads in the order specified by the
-    /// given mapping argument.
-    /// Return failure is mapping is not a valid permutation.
-    FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b, ArrayRef<int64_t> mapping);
+    /// Helper to sort `values` according to matching `keys`.
+    /// Take a custom `compare` binary comparator which returns true if the first
+    /// element is smaller than the second (i.e. compatible with std::sort).
+    /// This is a helper typically used to sort numThreads values before they are
+    /// mapped to concrete physical dimensions of hardware.
+    static SmallVector<Value> getValuesSortedByKey(
+        ArrayRef<Attribute> keys, ValueRange values,
+        llvm::function_ref<bool(Attribute, Attribute)> compare);
 
     // The ensureTerminator method generated by SingleBlockImplicitTerminator is
     // unaware of the fact that our terminator also needs a region to be
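The hunk above only adds the declaration; the matching definition lives elsewhere in the patch, presumably in the SCF dialect sources, and is not among the hunks shown here. A plausible sketch of such a helper, assuming the usual argsort-then-gather approach; the function name and placement below are illustrative, not the committed definition.

// Sketch of a key-based value sort, assuming keys and values line up 1-1.
#include <algorithm>
#include <cassert>
#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/ValueRange.h"

using namespace mlir;

static SmallVector<Value>
getValuesSortedByKeySketch(ArrayRef<Attribute> keys, ValueRange values,
                           llvm::function_ref<bool(Attribute, Attribute)> compare) {
  assert(keys.size() == values.size() && "expected matching keys and values");
  // Sort indices rather than the values themselves so the comparator only
  // ever sees keys.
  SmallVector<int64_t> indices(keys.size());
  std::iota(indices.begin(), indices.end(), 0);
  std::sort(indices.begin(), indices.end(),
            [&](int64_t a, int64_t b) { return compare(keys[a], keys[b]); });
  // Gather the values in the sorted key order.
  SmallVector<Value> sorted;
  sorted.reserve(values.size());
  for (int64_t i : indices)
    sorted.push_back(values[i]);
  return sorted;
}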

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 138 additions & 88 deletions
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/None.h"
@@ -157,45 +158,75 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
                          SmallVectorImpl<Value> &)>
         blockIdGenerator,
     SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp) {
+  // Step 0. Target-specific verifications. There is no good place to anchor
+  // those right now: the ForeachThreadOp is target-independent and the
+  // transform op does not apply to individual ForeachThreadOp.
+  MLIRContext *ctx = foreachThreadOp->getContext();
+  Location loc = foreachThreadOp->getLoc();
+  Attribute bX = GPUBlockMappingAttr::get(ctx, Blocks::DimX);
+  Attribute bY = GPUBlockMappingAttr::get(ctx, Blocks::DimY);
+  Attribute bZ = GPUBlockMappingAttr::get(ctx, Blocks::DimZ);
   if (foreachThreadOp.getNumResults() > 0)
     return transformOp.emitSilenceableError()
-           << "only bufferized scf.foreach_thread lowers to gpu.block_id";
+           << "only bufferized scf.foreach_thread lowers to "
+              "gpu.block_id";
   if (foreachThreadOp.getNumThreads().size() > 3)
     return transformOp.emitSilenceableError()
-           << "scf.foreach_thread with rank > 3 does not lower to gpu.block_id";
-
-  // Step 0. Outline the compute workload region and set up the workload
-  // operands.
-  SmallVector<int64_t> mapping;
+           << "scf.foreach_thread with rank > 3 does not lower to "
+              "gpu.block_id";
+  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
+        return !v.getDefiningOp<arith::ConstantIndexOp>();
+      })) {
+    return transformOp.emitSilenceableError()
+           << "unsupported dynamic griddim size";
+  }
   if (!foreachThreadOp.getMapping().has_value())
     return transformOp.emitSilenceableError() << "mapping must be present";
-  for (DeviceMappingAttrInterface map :
-       foreachThreadOp.getMapping()->getValue()) {
-    if (auto blockMap = map.dyn_cast<GPUBlockMappingAttr>()) {
-      mapping.push_back((int64_t)blockMap.getBlock());
-    } else {
-      return transformOp.emitSilenceableError()
-             << "mapping must be #gpu.block<x/y/z/>";
-    }
+  SmallVector<Attribute> blockMapping =
+      llvm::to_vector(foreachThreadOp.getMapping()->getValue());
+  if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) {
+        return !map.isa<GPUBlockMappingAttr>();
+      })) {
+    return transformOp.emitSilenceableError()
+           << "mapping must be #gpu.block<x/y/z/>";
   }
 
-  FailureOr<SmallVector<OpFoldResult>> potentialGridDim =
-      foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
-
-  if (failed(potentialGridDim) ||
-      llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
-      })) {
-    return transformOp.emitSilenceableError() << "unsupported dynamic gridDim";
+  // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
+  SmallVector<Value> numBlocks =
+      llvm::to_vector(foreachThreadOp.getNumThreads());
+  // Ensure we have 3 block sizes, one for each id.
+  Value one;
+  for (auto attr : {bX, bY, bZ}) {
+    if (std::find(blockMapping.begin(), blockMapping.end(), attr) ==
+        blockMapping.end()) {
+      blockMapping.push_back(attr);
+      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      numBlocks.push_back(one);
+    }
   }
 
-  for (OpFoldResult ofr : *potentialGridDim)
-    gridDims.push_back(getConstantIntValue(ofr).value());
+  // Step 2. sort the values by the corresponding GPUBlockMappingAttr.
+  auto comparator = [](Attribute a, Attribute b) -> bool {
+    return static_cast<int64_t>(a.cast<GPUBlockMappingAttr>().getBlock()) <
+           static_cast<int64_t>(b.cast<GPUBlockMappingAttr>().getBlock());
+  };
+  SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
+      blockMapping, numBlocks, comparator);
+  for (Value v : gridDimValues)
+    gridDims.push_back(v.getDefiningOp<arith::ConstantIndexOp>().value());
 
+  // Step 3. Generate the blockIds using the provided generator and map the
+  // induction variables to the newly created ops.
   SmallVector<Value> blockOps;
   blockIdGenerator(rewriter, foreachThreadOp, blockOps);
+  BlockAndValueMapping bvm;
+  for (auto [blockIdx, blockDim] :
+       llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
+    bvm.map(blockIdx, blockOps[static_cast<int64_t>(
+                          blockDim.cast<GPUBlockMappingAttr>().getBlock())]);
+  }
 
-  // Step 1. Move the body of foreachThreadOp.
+  // Step 4. Move the body of foreachThreadOp.
   // Erase the terminator first, it will not be used since we are on buffers.
   rewriter.eraseOp(foreachThreadOp.getTerminator());
   Block *targetBlock = foreachThreadOp->getBlock();
@@ -204,20 +235,16 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
   targetBlock->getOperations().splice(insertionPoint,
                                       sourceBlock.getOperations());
 
-  // Step 2. RAUW thread indices to thread ops.
-  SmallVector<Value> threadIndices =
-      *foreachThreadOp.getPermutedThreadIndices(mapping);
-  assert(blockOps.size() == 3 && "3 block id ops are required");
-  for (auto [blockIdx, blockOp] : llvm::zip(threadIndices, blockOps)) {
-    Value val = blockIdx;
-    Value blkOp = blockOp;
-    if (!val)
-      continue;
-    for (Operation *user : llvm::make_early_inc_range(val.getUsers()))
-      user->replaceUsesOfWith(val, blkOp);
+  // Step 5. RAUW thread indices to thread ops.
+  for (Value blockIdx : foreachThreadOp.getThreadIndices()) {
+    for (Operation *user : llvm::make_early_inc_range(blockIdx.getUsers())) {
+      rewriter.updateRootInPlace(user, [&]() {
+        user->replaceUsesOfWith(blockIdx, bvm.lookup(blockIdx));
+      });
+    }
   }
 
-  // Step 3. Erase old op.
+  // Step 6. Erase old op.
   rewriter.eraseOp(foreachThreadOp);
 
   return DiagnosedSilenceableFailure::success();
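The BlockAndValueMapping built in Step 3 and consumed in Step 5 replaces the old positional getPermutedThreadIndices lookup: each induction variable is tied directly to the hardware id selected by its mapping attribute, independent of the textual order of the mapping. A standalone plain C++ sketch of that association; strings stand in for MLIR values and all names are illustrative.

// Toy analogue of the Step 3 association and Step 5 replacement.
#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <vector>

enum class Dim : int { X = 0, Y = 1, Z = 2 };

int main() {
  // blockOps[0..2] stand in for gpu.block_id x/y/z.
  std::vector<std::string> blockOps = {"block_id_x", "block_id_y", "block_id_z"};
  // A 2-D foreach_thread whose induction variables are mapped as (Y, X).
  std::vector<std::string> inductionVars = {"%i", "%j"};
  std::vector<Dim> blockMapping = {Dim::Y, Dim::X};

  // Analogue of BlockAndValueMapping: iv -> hardware id it is rewritten to.
  std::map<std::string, std::string> bvm;
  for (size_t k = 0; k < inductionVars.size(); ++k)
    bvm[inductionVars[k]] = blockOps[static_cast<int>(blockMapping[k])];

  assert(bvm["%i"] == "block_id_y" && bvm["%j"] == "block_id_x");
  for (auto &[iv, id] : bvm)
    std::cout << iv << " -> " << id << "\n";
}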
@@ -252,11 +279,10 @@ static void generateGpuBlockIds(RewriterBase &rewriter,
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(foreachOp);
   IndexType indexType = rewriter.getIndexType();
-  SmallVector<Dimension> gpuDims{Dimension::x, Dimension::y, Dimension::z};
-  for (int64_t idx : llvm::seq<int64_t>(0, gpuDims.size())) {
-    blockOps.push_back(
-        rewriter.create<BlockIdOp>(loc, indexType, gpuDims[idx]));
-  }
+  blockOps = SmallVector<Value>{
+      rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
+      rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
+      rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
 }
 
 DiagnosedSilenceableFailure
@@ -333,61 +359,89 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
     const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
     llvm::Optional<TransformOpInterface> transformOp) {
+  // Step 0. Target-specific verifications. There is no good place to anchor
+  // those right now: the ForeachThreadOp is target-independent and the
+  // transform op does not apply to individual ForeachThreadOp.
   auto failureHelper =
       [&](const Twine &message) -> DiagnosedSilenceableFailure {
     if (transformOp.has_value()) {
      return transformOp->emitSilenceableError() << message;
    }
    return emitDefiniteFailure(foreachThreadOp, message);
   };
-
+  MLIRContext *ctx = foreachThreadOp->getContext();
+  Location loc = foreachThreadOp->getLoc();
+  Attribute tX = GPUThreadMappingAttr::get(ctx, Threads::DimX);
+  Attribute tY = GPUThreadMappingAttr::get(ctx, Threads::DimY);
+  Attribute tZ = GPUThreadMappingAttr::get(ctx, Threads::DimZ);
   if (foreachThreadOp.getNumResults() > 0)
     return failureHelper(
         "only bufferized scf.foreach_thread lowers to gpu.thread_id");
-
   if (foreachThreadOp.getNumThreads().size() > 3)
     return failureHelper(
         "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
-
-  SmallVector<int64_t> mapping;
+  if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
+        return !v.getDefiningOp<arith::ConstantIndexOp>();
+      })) {
+    return failureHelper("unsupported dynamic blockdim size");
+  }
   if (!foreachThreadOp.getMapping().has_value())
     return failureHelper("mapping must be present");
-  for (DeviceMappingAttrInterface map :
-       foreachThreadOp.getMapping()->getValue()) {
-    if (auto threadMap = map.dyn_cast<GPUThreadMappingAttr>()) {
-      mapping.push_back((int64_t)threadMap.getThread());
-    } else {
-      return failureHelper("mapping must be #gpu.thread<x/y/z/>");
-    }
-  }
-  FailureOr<SmallVector<OpFoldResult>> potentialBlockDim =
-      foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
-  if (failed(potentialBlockDim) ||
-      llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
+  SmallVector<Attribute> threadMapping =
+      llvm::to_vector(foreachThreadOp.getMapping()->getValue());
+  if (llvm::any_of(threadMapping, [](DeviceMappingAttrInterface map) {
+        return !map.isa<GPUThreadMappingAttr>();
       })) {
-    return failureHelper("unsupported dynamic blockdim size");
+    return transformOp->emitSilenceableError()
+           << "mapping must be #gpu.thread<x/y/z/>";
   }
 
-  SmallVector<int64_t> blockDim =
-      llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) {
-        return getConstantIntValue(ofr).value();
+  // Step 1. Complete the threadMapping to a full mapping (with 1s) if
+  // necessary.
+  SmallVector<Value> numThreads =
+      llvm::to_vector(foreachThreadOp.getNumThreads());
+  // Ensure we have 3 block sizes, one for each id.
+  Value one;
+  for (auto attr : {tX, tY, tZ}) {
+    if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
+        threadMapping.end()) {
+      threadMapping.push_back(attr);
+      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      numThreads.push_back(one);
+    }
+  }
+
+  // Step 2. sort the values by the corresponding GPUThreadMappingAttr.
+  auto comparator = [](Attribute a, Attribute b) -> bool {
+    return static_cast<int64_t>(a.cast<GPUThreadMappingAttr>().getThread()) <
+           static_cast<int64_t>(b.cast<GPUThreadMappingAttr>().getThread());
+  };
+  SmallVector<Value> blockDimValues =
+      scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads,
+                                                 comparator);
+  SmallVector<int64_t> blockDims =
+      llvm::to_vector(llvm::map_range(blockDimValues, [](Value v) {
+        return v.getDefiningOp<arith::ConstantIndexOp>().value();
       }));
 
-  // Step 1. Create the gpu.thread ops
-  Location loc = foreachThreadOp.getLoc();
+  // Step 3. Create the gpu.thread ops and map the induction variables to the
+  // newly created ops.
   IndexType indexType = rewriter.getIndexType();
-
-  SmallVector<Dimension> gpuDims{Dimension::x, Dimension::y, Dimension::z};
-  SmallVector<Value> threadOps;
-  for (int64_t idx : llvm::seq<int64_t>(0, blockDim.size())) {
-    threadOps.push_back(
-        rewriter.create<ThreadIdOp>(loc, indexType, gpuDims[idx]));
+  SmallVector<Value> threadOps{
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
+      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
+  BlockAndValueMapping bvm;
+  for (auto [blockIdx, blockDim] :
+       llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
+    bvm.map(blockIdx, threadOps[static_cast<int64_t>(
+                          blockDim.cast<GPUThreadMappingAttr>().getThread())]);
   }
-  // Step 2. Maybe create conditionals to predicate the region.
+
+  // Step 4. Maybe create conditionals to predicate the region.
   Value predicate;
   for (auto [threadId, blockDim, globalBlockDim] :
-       llvm::zip(threadOps, blockDim, globalBlockDims)) {
+       llvm::zip(threadOps, blockDims, globalBlockDims)) {
     if (blockDim > globalBlockDim) {
       return failureHelper(
           "The requested GPU threads are fewer than the number of loop trip "
@@ -404,45 +458,41 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
                     : tmpPredicate;
   }
 
-  // Step 3. Move the body of foreachThreadOp.
+  // Step 5. Move the body of foreachThreadOp.
   // Erase the terminator first, it will not be used.
   rewriter.eraseOp(foreachThreadOp.getTerminator());
   Block *targetBlock;
   Block::iterator insertionPoint;
   if (predicate) {
-    // Step 3.a. If predicated, move at the beginning.
+    // Step 5.a. If predicated, move at the beginning.
     auto ifOp =
         rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
     targetBlock = ifOp.thenBlock();
     insertionPoint = ifOp.thenBlock()->begin();
   } else {
-    // Step 3.a. Otherwise, move inline just before foreachThreadOp.
+    // Step 5.b. Otherwise, move inline just before foreachThreadOp.
     targetBlock = foreachThreadOp->getBlock();
     insertionPoint = Block::iterator(foreachThreadOp);
   }
   Block &sourceBlock = foreachThreadOp.getRegion().front();
   targetBlock->getOperations().splice(insertionPoint,
                                       sourceBlock.getOperations());
 
-  // Step 4. RAUW thread indices to thread ops.
-  SmallVector<Value> threadIndices =
-      *foreachThreadOp.getPermutedThreadIndices(mapping);
-  for (auto [threadIdx, threadOp] : llvm::zip(threadIndices, threadOps)) {
-    Value val = threadIdx;
-    Value op = threadOp;
-    if (!val)
-      continue;
-    for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
-      user->replaceUsesOfWith(val, op);
+  // Step 6. RAUW thread indices to thread ops.
+  for (Value threadIdx : foreachThreadOp.getThreadIndices()) {
+    for (Operation *user : llvm::make_early_inc_range(threadIdx.getUsers())) {
+      rewriter.updateRootInPlace(user, [&]() {
+        user->replaceUsesOfWith(threadIdx, bvm.lookup(threadIdx));
+      });
     }
   }
 
-  // Step 5. syncthreads.
+  // Step 7. syncthreads.
   // TODO: Need warpsync
   if (syncAfterDistribute)
     rewriter.create<BarrierOp>(loc);
 
-  // Step 6. Erase old op.
+  // Step 8. Erase old op.
   rewriter.eraseOp(foreachThreadOp);
 
   return DiagnosedSilenceableFailure::success();
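For the thread-level path, Step 4 handles dimensions whose sorted loop trip count differs from the launch-time block dimension. The hunk above only shows the over-subscription error; the standalone sketch below fills in the per-dimension decision under the assumption that a guard of the form threadId < tripCount is emitted whenever the trip count is strictly smaller than the block dimension. The values and names are illustrative, not taken from the patch.

// Toy per-dimension check behind the Step 4 predication.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> threadIds = {"thread_id_x", "thread_id_y", "thread_id_z"};
  std::vector<int64_t> blockDims = {9, 7, 1};        // sorted loop trip counts
  std::vector<int64_t> globalBlockDims = {32, 8, 1}; // requested kernel blockDims

  for (size_t d = 0; d < 3; ++d) {
    if (blockDims[d] > globalBlockDims[d]) {
      std::cerr << "error: requested GPU threads are fewer than the loop trip count\n";
      return 1;
    }
    if (blockDims[d] == globalBlockDims[d])
      continue; // every hardware thread has work; no guard needed.
    // Fewer trips than hardware threads: guard the body in this dimension.
    std::cout << "predicate: " << threadIds[d] << " < " << blockDims[d] << "\n";
  }
}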
