Skip to content

Commit 2685ba6

Browse files
authored
Add cudart-to-hiprt conversion (#2016)
* add cudart-to-hiprt conversion, map cudaFree to hipFree
* fix
* fmt
* fmt
* add pass in raise pipeline
* fix
1 parent 3f01322 commit 2685ba6

File tree

4 files changed

+82
-8
lines changed

4 files changed

+82
-8
lines changed

src/enzyme_ad/jax/Passes/ParallelLower.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
#include "llvm/ADT/SmallPtrSet.h"
3232
#include "llvm/ADT/StringRef.h"
3333
#include "llvm/Support/DebugLog.h"
34+
#include <llvm/ADT/SmallVector.h>
35+
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
3436

3537
#include "Enzyme/MLIR/Dialect/Ops.h"
3638
#include "Enzyme/MLIR/Passes/Passes.h"
@@ -45,6 +47,7 @@ namespace enzyme {
4547
#define GEN_PASS_DEF_PARALLELLOWER
4648
#define GEN_PASS_DEF_FIXGPUFUNC
4749
#define GEN_PASS_DEF_STRIPGPUINFO
50+
#define GEN_PASS_DEF_CONVERTCUDARTTOHIPRT
4851
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
4952
} // namespace enzyme
5053
} // namespace mlir
@@ -109,11 +112,13 @@ struct ConvertCudaRTtoCPU : public ConvertCudaRTtoCPUBase<ConvertCudaRTtoCPU> {
109112
struct ConvertCudaRTtoGPU : public ConvertCudaRTtoGPUBase<ConvertCudaRTtoGPU> {
110113
void runOnOperation() override;
111114
};
115+
*/
116+
112117
// Pass class for the CUDA-runtime-to-HIP-runtime conversion. It derives from
// the ConvertCudaRTtoHipRTBase CRTP base (referenced through enzyme::impl in
// the added line; presumably the code emitted by the Passes.h.inc include
// under GEN_PASS_DEF_CONVERTCUDARTTOHIPRT) and overrides runOnOperation(),
// which is defined out of line.
struct ConvertCudaRTtoHipRT
113-
: public ConvertCudaRTtoHipRTBase<ConvertCudaRTtoHipRT> {
118+
: public enzyme::impl::ConvertCudaRTtoHipRTBase<ConvertCudaRTtoHipRT> {
114119
void runOnOperation() override;
115120
};
116-
*/
121+
117122
struct FixGPUFunc : public enzyme::impl::FixGPUFuncBase<FixGPUFunc> {
118123
using FixGPUFuncBase::FixGPUFuncBase;
119124
void runOnOperation() override;
@@ -1330,18 +1335,19 @@ static void setCallee(LLVM::CallOp call, StringRef symName) {
13301335
call.setCallee(symName);
13311336
}
13321337
template <typename CallOpTy, typename FuncOpTy>
1333-
void replaceCallOp(ModuleOp m, CallOpTy call, llvm::StringRef callee) {
1334-
auto loc = call->getLoc();
1335-
OpBuilder moduleBuilder = OpBuilder::atBlockEnd(m.getBody());
1338+
void replaceCallOp(ModuleOp m, CallOpTy call, llvm::StringRef callee,
1339+
SmallPtrSetImpl<Operation *> &toErase) {
13361340
OpBuilder callBuilder(call);
13371341
auto funcOp = m.lookupSymbol<FuncOpTy>(callee);
13381342
if (isHipCallEquivalent(callee)) {
13391343
assert(funcOp);
13401344
auto hipName = getHipName(callee);
13411345
if (!m.lookupSymbol<FuncOpTy>(hipName)) {
1346+
OpBuilder moduleBuilder(funcOp.getOperation());
13421347
auto hipFuncOp =
13431348
cast<FuncOpTy>(moduleBuilder.clone(*funcOp.getOperation()));
13441349
hipFuncOp.setSymName(hipName);
1350+
toErase.insert(funcOp.getOperation());
13451351
}
13461352
setCallee(call, hipName);
13471353
} else {
@@ -1351,24 +1357,30 @@ void replaceCallOp(ModuleOp m, CallOpTy call, llvm::StringRef callee) {
13511357
}
13521358
}
13531359

1354-
#if 0
13551360
void ConvertCudaRTtoHipRT::runOnOperation() {
1361+
SmallPtrSet<Operation *, 8> toErase;
1362+
13561363
getOperation().walk([&](LLVM::CallOp call) {
13571364
if (!call.getCallee())
13581365
return;
13591366
auto name = *call.getCallee();
13601367
if (!isCudartCall(name))
13611368
return;
1362-
replaceCallOp<LLVM::CallOp, LLVM::LLVMFuncOp>(getOperation(), call, name);
1369+
replaceCallOp<LLVM::CallOp, LLVM::LLVMFuncOp>(getOperation(), call, name,
1370+
toErase);
13631371
});
13641372

13651373
getOperation().walk([&](CallOp call) {
13661374
auto name = call.getCallee();
13671375
if (!isCudartCall(name))
13681376
return;
1369-
replaceCallOp<CallOp, func::FuncOp>(getOperation(), call, name);
1377+
replaceCallOp<CallOp, func::FuncOp>(getOperation(), call, name, toErase);
13701378
});
13711379

1380+
// Erase old CUDA function declarations after all calls are updated
1381+
for (Operation *op : toErase)
1382+
op->erase();
1383+
13721384
OpBuilder builder(&getContext());
13731385
getOperation().walk([&](mlir::NVVM::Barrier0Op op) {
13741386
builder.setInsertionPoint(op);
@@ -1377,6 +1389,7 @@ void ConvertCudaRTtoHipRT::runOnOperation() {
13771389
});
13781390
}
13791391

1392+
#if 0
13801393
void ConvertCudaRTtoGPU::runOnOperation() {
13811394
std::function<void(Operation * call, llvm::StringRef callee)> replaceWithOp =
13821395
[&](Operation *call, llvm::StringRef callee) {

src/enzyme_ad/jax/Passes/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,13 @@ def ParallelLower : Pass<"parallel-lower"> {
10651065
];
10661066
}
10671067

1068+
// Declares the `convert-cudart-to-hiprt` pass, anchored on mlir::ModuleOp.
// The ROCDL dialect is listed as a dependent dialect; presumably this is
// because the pass creates ROCDL ops when replacing NVVM barriers (the lit
// test for this pass expects `rocdl.barrier`) -- confirm against the C++
// implementation.
def ConvertCudaRTtoHipRT : Pass<"convert-cudart-to-hiprt", "mlir::ModuleOp"> {
1069+
let summary = "Convert CUDA runtime calls to HIP runtime calls";
1070+
let dependentDialects = [
1071+
"mlir::ROCDL::ROCDLDialect",
1072+
];
1073+
}
1074+
10681075
def SCFParallelLoopUnroll : Pass<"scf-parallel-loop-unroll"> {
10691076
let summary = "Unroll and interleave scf parallel loops";
10701077
let dependentDialects = [

src/enzyme_ad/jax/raise.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ extern "C" std::string runLLVMToMLIRRoundTrip(std::string input,
131131
pass_pipeline += "print{filename="+outfile+".mlir},";
132132
}
133133
pass_pipeline += "symbol-dce,enzyme,remove-unnecessary-enzyme-ops,lower-affine";
134+
if (backend == "rocm")
135+
pass_pipeline += ",convert-cudart-to-hiprt";
134136
if (backend != "cpu")
135137
pass_pipeline += ",convert-parallel-to-gpu1,gpu-kernel-outlining,canonicalize,convert-parallel-to-gpu2{backend=";
136138
pass_pipeline += backend;
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(convert-cudart-to-hiprt)" | FileCheck %s
2+
3+
module {
4+
llvm.func @cudaMalloc(!llvm.ptr, i64) -> i32
5+
llvm.func @cudaFree(!llvm.ptr) -> i32
6+
llvm.func @cudaMemcpy(!llvm.ptr, !llvm.ptr, i64, i32) -> i32
7+
llvm.func @cudaDeviceSynchronize() -> i32
8+
llvm.func @cudaMemset(!llvm.ptr, i32, i64) -> i32
9+
llvm.func @cudaGetLastError() -> i32
10+
11+
llvm.func @test_llvm_cuda_calls(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64) -> i32 {
12+
%c0 = llvm.mlir.constant(0 : i32) : i32
13+
%c1 = llvm.mlir.constant(1 : i32) : i32
14+
15+
%0 = llvm.call @cudaMalloc(%arg0, %arg2) : (!llvm.ptr, i64) -> i32
16+
%1 = llvm.call @cudaMemcpy(%arg0, %arg1, %arg2, %c1) : (!llvm.ptr, !llvm.ptr, i64, i32) -> i32
17+
%2 = llvm.call @cudaMemset(%arg0, %c0, %arg2) : (!llvm.ptr, i32, i64) -> i32
18+
%3 = llvm.call @cudaDeviceSynchronize() : () -> i32
19+
%4 = llvm.call @cudaFree(%arg0) : (!llvm.ptr) -> i32
20+
%5 = llvm.call @cudaGetLastError() : () -> i32
21+
22+
llvm.return %c0 : i32
23+
}
24+
25+
llvm.func @test_nvvm_barrier_conversion(%arg0: !llvm.ptr) {
26+
%0 = llvm.mlir.constant(42 : i32) : i32
27+
llvm.store %0, %arg0 : i32, !llvm.ptr
28+
nvvm.barrier0
29+
%1 = llvm.load %arg0 : !llvm.ptr -> i32
30+
llvm.return
31+
}
32+
}
33+
34+
// CHECK-DAG: llvm.func @hipMalloc(!llvm.ptr, i64) -> i32
35+
// CHECK-DAG: llvm.func @hipFree(!llvm.ptr) -> i32
36+
// CHECK-DAG: llvm.func @hipMemcpy(!llvm.ptr, !llvm.ptr, i64, i32) -> i32
37+
// CHECK-DAG: llvm.func @hipDeviceSynchronize() -> i32
38+
// CHECK-DAG: llvm.func @hipMemset(!llvm.ptr, i32, i64) -> i32
39+
// CHECK-DAG: llvm.func @hipGetLastError() -> i32
40+
41+
// CHECK-LABEL: llvm.func @test_llvm_cuda_calls
42+
// CHECK: llvm.call @hipMalloc
43+
// CHECK: llvm.call @hipMemcpy
44+
// CHECK: llvm.call @hipMemset
45+
// CHECK: llvm.call @hipDeviceSynchronize
46+
// CHECK: llvm.call @hipFree
47+
// CHECK: llvm.call @hipGetLastError
48+
49+
// CHECK-LABEL: llvm.func @test_nvvm_barrier_conversion
50+
// CHECK: llvm.store
51+
// CHECK: rocdl.barrier
52+
// CHECK: llvm.load

0 commit comments

Comments
 (0)