3131#include " llvm/ADT/SmallPtrSet.h"
3232#include " llvm/ADT/StringRef.h"
3333#include " llvm/Support/DebugLog.h"
34+ #include < llvm/ADT/SmallVector.h>
35+ #include < mlir/Dialect/LLVMIR/LLVMDialect.h>
3436
3537#include " Enzyme/MLIR/Dialect/Ops.h"
3638#include " Enzyme/MLIR/Passes/Passes.h"
@@ -45,6 +47,7 @@ namespace enzyme {
4547#define GEN_PASS_DEF_PARALLELLOWER
4648#define GEN_PASS_DEF_FIXGPUFUNC
4749#define GEN_PASS_DEF_STRIPGPUINFO
50+ #define GEN_PASS_DEF_CONVERTCUDARTTOHIPRT
4851#include " src/enzyme_ad/jax/Passes/Passes.h.inc"
4952} // namespace enzyme
5053} // namespace mlir
@@ -109,11 +112,13 @@ struct ConvertCudaRTtoCPU : public ConvertCudaRTtoCPUBase<ConvertCudaRTtoCPU> {
109112struct ConvertCudaRTtoGPU : public ConvertCudaRTtoGPUBase<ConvertCudaRTtoGPU> {
110113 void runOnOperation() override;
111114};
115+ */
116+
112117struct ConvertCudaRTtoHipRT
113- : public ConvertCudaRTtoHipRTBase<ConvertCudaRTtoHipRT> {
118+ : public enzyme::impl:: ConvertCudaRTtoHipRTBase<ConvertCudaRTtoHipRT> {
114119 void runOnOperation () override ;
115120};
116- */
121+
117122struct FixGPUFunc : public enzyme ::impl::FixGPUFuncBase<FixGPUFunc> {
118123 using FixGPUFuncBase::FixGPUFuncBase;
119124 void runOnOperation () override ;
@@ -1330,18 +1335,19 @@ static void setCallee(LLVM::CallOp call, StringRef symName) {
13301335 call.setCallee (symName);
13311336}
13321337template <typename CallOpTy, typename FuncOpTy>
1333- void replaceCallOp (ModuleOp m, CallOpTy call, llvm::StringRef callee) {
1334- auto loc = call->getLoc ();
1335- OpBuilder moduleBuilder = OpBuilder::atBlockEnd (m.getBody ());
1338+ void replaceCallOp (ModuleOp m, CallOpTy call, llvm::StringRef callee,
1339+ SmallPtrSetImpl<Operation *> &toErase) {
13361340 OpBuilder callBuilder (call);
13371341 auto funcOp = m.lookupSymbol <FuncOpTy>(callee);
13381342 if (isHipCallEquivalent (callee)) {
13391343 assert (funcOp);
13401344 auto hipName = getHipName (callee);
13411345 if (!m.lookupSymbol <FuncOpTy>(hipName)) {
1346+ OpBuilder moduleBuilder (funcOp.getOperation ());
13421347 auto hipFuncOp =
13431348 cast<FuncOpTy>(moduleBuilder.clone (*funcOp.getOperation ()));
13441349 hipFuncOp.setSymName (hipName);
1350+ toErase.insert (funcOp.getOperation ());
13451351 }
13461352 setCallee (call, hipName);
13471353 } else {
@@ -1351,24 +1357,30 @@ void replaceCallOp(ModuleOp m, CallOpTy call, llvm::StringRef callee) {
13511357 }
13521358}
13531359
1354- #if 0
13551360void ConvertCudaRTtoHipRT::runOnOperation () {
1361+ SmallPtrSet<Operation *, 8 > toErase;
1362+
13561363 getOperation ().walk ([&](LLVM::CallOp call) {
13571364 if (!call.getCallee ())
13581365 return ;
13591366 auto name = *call.getCallee ();
13601367 if (!isCudartCall (name))
13611368 return ;
1362- replaceCallOp<LLVM::CallOp, LLVM::LLVMFuncOp>(getOperation(), call, name);
1369+ replaceCallOp<LLVM::CallOp, LLVM::LLVMFuncOp>(getOperation (), call, name,
1370+ toErase);
13631371 });
13641372
13651373 getOperation ().walk ([&](CallOp call) {
13661374 auto name = call.getCallee ();
13671375 if (!isCudartCall (name))
13681376 return ;
1369- replaceCallOp<CallOp, func::FuncOp>(getOperation(), call, name);
1377+ replaceCallOp<CallOp, func::FuncOp>(getOperation (), call, name, toErase );
13701378 });
13711379
1380+ // Erase old CUDA function declarations after all calls are updated
1381+ for (Operation *op : toErase)
1382+ op->erase ();
1383+
13721384 OpBuilder builder (&getContext ());
13731385 getOperation ().walk ([&](mlir::NVVM::Barrier0Op op) {
13741386 builder.setInsertionPoint (op);
@@ -1377,6 +1389,7 @@ void ConvertCudaRTtoHipRT::runOnOperation() {
13771389 });
13781390}
13791391
1392+ #if 0
13801393void ConvertCudaRTtoGPU::runOnOperation() {
13811394 std::function<void(Operation * call, llvm::StringRef callee)> replaceWithOp =
13821395 [&](Operation *call, llvm::StringRef callee) {
0 commit comments