Skip to content

Commit 4c372a3

Browse files
committed
[mlir] Make GpuAsyncRegion pass depend on async dialect.
Do not cache gpu.async.token type so that the pass can be created before the GPU dialect is registered. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D94397
1 parent 4fe7b16 commit 4c372a3

File tree

3 files changed

+5
-1
lines changed

3 files changed

+5
-1
lines changed

mlir/include/mlir/Dialect/GPU/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def GpuKernelOutlining : Pass<"gpu-kernel-outlining", "ModuleOp"> {
1919
def GpuAsyncRegionPass : FunctionPass<"gpu-async-region"> {
2020
let summary = "Make GPU ops async";
2121
let constructor = "mlir::createGpuAsyncRegionPass()";
22+
let dependentDialects = ["async::AsyncDialect"];
2223
}
2324

2425
#endif // MLIR_DIALECT_GPU_PASSES

mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
7878
if (op->getNumRegions() > 0)
7979
return op->emitOpError("regions are not supported");
8080

81+
auto tokenType = builder.getType<gpu::AsyncTokenType>();
82+
8183
// If there is no current token, insert a `gpu.wait async` without
8284
// dependencies to create one.
8385
if (!currentToken)
@@ -108,7 +110,7 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
108110
}
109111

110112
OpBuilder builder;
111-
const Type tokenType = builder.getType<gpu::AsyncTokenType>();
113+
112114
// The token that represents the current asynchronous dependency. It's valid
113115
// range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
114116
// In between, each gpu::AsyncOpInterface depends on the current token and

mlir/lib/Dialect/GPU/Transforms/PassDetail.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef DIALECT_GPU_TRANSFORMS_PASSDETAIL_H_
1010
#define DIALECT_GPU_TRANSFORMS_PASSDETAIL_H_
1111

12+
#include "mlir/Dialect/Async/IR/Async.h"
1213
#include "mlir/Pass/Pass.h"
1314

1415
namespace mlir {

0 commit comments

Comments
 (0)