Skip to content

Commit b1d7223

Browse files
ftynse and wsmoses authored
use addrspacecast instead of bitcast for pointers (#1281)
* use addrspacecast instead of bitcast for pointers

  We used to be bitcasting pointers because of typed pointers, which are now gone, but we are getting address space mismatches. Repurpose bitcast insertion for address space casts. Not sure this is the right approach or whether we should change the function signature to accept pointers in other address spaces.

* fmt

---------

Co-authored-by: William S. Moses <[email protected]>
1 parent e0cdb7f commit b1d7223

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1891,7 +1891,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
18911891

18921892
auto addressOfWrapper =
18931893
ctorBuilder.create<LLVM::AddressOfOp>(loc, fatBinWrapper);
1894-
auto bitcastOfWrapper = ctorBuilder.create<LLVM::BitcastOp>(
1894+
auto bitcastOfWrapper = ctorBuilder.create<LLVM::AddrSpaceCastOp>(
18951895
loc, llvmPointerType, addressOfWrapper);
18961896

18971897
auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn(
@@ -1954,8 +1954,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
19541954
rewriter.create<LLVM::ReturnOp>(loc, ValueRange());
19551955
}
19561956
auto aoo = ctorBuilder.create<LLVM::AddressOfOp>(loc, stub);
1957-
auto bitcast =
1958-
ctorBuilder.create<LLVM::BitcastOp>(loc, llvmPointerType, aoo);
1957+
auto bitcast = ctorBuilder.create<LLVM::AddrSpaceCastOp>(
1958+
loc, llvmPointerType, aoo);
19591959

19601960
Type tys[] = {llvmPointerType, llvmPointerType, llvmPointerType,
19611961
llvmPointerType, llvmInt32Type, llvmPointerType,
@@ -2003,8 +2003,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
20032003
auto stub = moduleOp.lookupSymbol<LLVM::GlobalOp>(g.getName());
20042004
assert(stub);
20052005
auto aoo = ctorBuilder.create<LLVM::AddressOfOp>(loc, stub);
2006-
auto bitcast =
2007-
ctorBuilder.create<LLVM::BitcastOp>(loc, llvmPointerType, aoo);
2006+
auto bitcast = ctorBuilder.create<LLVM::AddrSpaceCastOp>(
2007+
loc, llvmPointerType, aoo);
20082008
auto globalTy = stub.getGlobalType();
20092009
// TODO This should actually be the GPUModuleOp's data layout I
20102010
// believe, there were problems with assigning the data layout to
@@ -2094,7 +2094,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
20942094
SymbolTable::lookupSymbolIn(moduleOp, funcStubName));
20952095
assert(!!stub);
20962096
auto aoo = rewriter.create<LLVM::AddressOfOp>(loc, stub);
2097-
auto bitcast = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, aoo);
2097+
auto bitcast =
2098+
rewriter.create<LLVM::AddrSpaceCastOp>(loc, llvmPointerType, aoo);
20982099

20992100
Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
21002101
auto nullpointer = rewriter.create<LLVM::ZeroOp>(loc, llvmPointerType);
@@ -3109,7 +3110,8 @@ struct ReconcileUnrealizedPointerCasts
31093110
if (!(isa<LLVM::LLVMPointerType>(inputTy) &&
31103111
isa<LLVM::LLVMPointerType>(outputTy)))
31113112
return failure();
3112-
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(ucc, outputTy, inputs[0]);
3113+
rewriter.replaceOpWithNewOp<LLVM::AddrSpaceCastOp>(ucc, outputTy,
3114+
inputs[0]);
31133115
return success();
31143116
}
31153117
};

0 commit comments

Comments
 (0)