diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 14e5234d2038..17f9a5a34c79 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -262,9 +262,16 @@ void CodeGen_PTX_Dev::visit(const Call *op) { auto fence_type_ptr = as_const_int(op->args[0]); internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n"; - llvm::Function *barrier0 = module->getFunction("llvm.nvvm.barrier0"); - internal_assert(barrier0) << "Could not find PTX barrier intrinsic (llvm.nvvm.barrier0)\n"; - builder->CreateCall(barrier0); + llvm::Function *barrier; + if ((barrier = module->getFunction("llvm.nvvm.barrier.cta.sync.aligned.all")) && barrier->getIntrinsicID() != 0) { + // LLVM 20.1.6 and above: https://github.com/llvm/llvm-project/pull/140615 + builder->CreateCall(barrier, builder->getInt32(0)); + } else if ((barrier = module->getFunction("llvm.nvvm.barrier0")) && barrier->getIntrinsicID() != 0) { + // LLVM 20.1.5 and below: Testing for llvm.nvvm.barrier0 can be removed once we drop support for LLVM 20 + builder->CreateCall(barrier); + } else { + internal_error << "Could not find PTX barrier intrinsic llvm.nvvm.barrier0 nor llvm.nvvm.barrier.cta.sync.aligned.all\n"; + } value = ConstantInt::get(i32_t, 0); return; } diff --git a/src/runtime/ptx_dev.ll b/src/runtime/ptx_dev.ll index 9cefaa53ec5b..e29574c74e91 100644 --- a/src/runtime/ptx_dev.ll +++ b/src/runtime/ptx_dev.ll @@ -1,4 +1,11 @@ -declare void @llvm.nvvm.barrier0() +; The two forward declared intrinsics below refer to the same thing. +; LLVM 20.1.6 introduced a new naming scheme for these intrinsics. +; We have to declare both, such that we can access them from the Module's +; getFunction(), but one of those will map to an intrinsic, which we +; will use to determine which intrinsic is supported by LLVM. 
+declare void @llvm.nvvm.barrier0() ; LLVM <=20.1.5 +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) ; LLVM >=20.1.6 + declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()