Skip to content

Commit 04b7d6f

Browse files
authored
[gpu] Insert gpu waits for gpu allocs and kernel launch (#196)
1 parent: 4069666 · commit: 04b7d6f

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

mlir/lib/Conversion/gpu_runtime_to_llvm.cpp

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -86,8 +86,8 @@ struct FunctionCallBuilder {
8686
mlir::LLVM::LLVMFunctionType functionType;
8787
};
8888

89-
static const char *kEventCountAttrName = "gpu.event_count";
90-
static const char *kEventIndexAttrName = "gpu.event_index";
89+
static constexpr llvm::StringLiteral kEventCountAttrName("gpu.event_count");
90+
static constexpr llvm::StringLiteral kEventIndexAttrName("gpu.event_index");
9191

9292
template <typename OpTy>
9393
class ConvertOpToGpuRuntimeCallPattern
@@ -559,12 +559,13 @@ class ConvertGpuKernelLaunchPattern
559559
eventIndexVar,
560560
// clang-format on
561561
};
562-
auto res = launchKernelCallBuilder.create(loc, rewriter, params);
562+
auto event =
563+
launchKernelCallBuilder.create(loc, rewriter, params)->getResult(0);
563564
if (op.getNumResults() == 0) {
565+
waitEventCallBuilder.create(loc, rewriter, event);
564566
rewriter.eraseOp(op);
565567
} else {
566-
assert(res.getNumResults() == op.getNumResults());
567-
rewriter.replaceOp(op, res.getResults());
568+
rewriter.replaceOp(op, event);
568569
}
569570
return mlir::success();
570571
}
@@ -659,11 +660,12 @@ class ConvertGpuAllocPattern
659660
}
660661

661662
mlir::Value resMemref = memrefDesc;
663+
mlir::Value event = rewriter.create<mlir::LLVM::ExtractValueOp>(
664+
loc, llvmPointerType, res, rewriter.getI64ArrayAttr(2));
662665
if (op.getNumResults() == 1) {
666+
waitEventCallBuilder.create(loc, rewriter, event);
663667
rewriter.replaceOp(op, resMemref);
664668
} else {
665-
auto event = rewriter.create<mlir::LLVM::ExtractValueOp>(
666-
loc, llvmPointerType, res, rewriter.getI64ArrayAttr(2));
667669
mlir::Value vals[] = {
668670
resMemref,
669671
event,
@@ -782,16 +784,16 @@ struct EnumerateEventsPass
782784
void runOnOperation() override {
783785
auto mod = getOperation();
784786
int64_t eventCount = 0;
785-
auto intType = mlir::IntegerType::get(&getContext(), 64);
787+
auto *ctx = &getContext();
788+
auto intType = mlir::IntegerType::get(ctx, 64);
789+
auto indexAttrName = mlir::StringAttr::get(ctx, kEventIndexAttrName);
790+
auto countAttrName = mlir::StringAttr::get(ctx, kEventCountAttrName);
786791
mod.walk([&](mlir::gpu::AsyncOpInterface op) {
787-
if (op.getAsyncToken()) {
788-
op->setAttr(kEventIndexAttrName,
789-
mlir::IntegerAttr::get(intType, eventCount));
790-
++eventCount;
791-
}
792+
// if (op.getAsyncToken())
793+
op->setAttr(indexAttrName, mlir::IntegerAttr::get(intType, eventCount));
794+
++eventCount;
792795
});
793-
mod->setAttr(kEventCountAttrName,
794-
mlir::IntegerAttr::get(intType, eventCount));
796+
mod->setAttr(countAttrName, mlir::IntegerAttr::get(intType, eventCount));
795797
}
796798
};
797799

0 commit comments

Comments
 (0)