From b9228eaa796f0f24e23e685a62b58a0b17d6449f Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Sun, 24 Aug 2025 16:16:13 -0700 Subject: [PATCH] Add PTX calling conventions to generated LLVM IR Fixes a TODO. Shouldn't affect anything but why not. --- crates/rustc_codegen_nvvm/src/context.rs | 11 +++++++---- crates/rustc_codegen_nvvm/src/llvm.rs | 14 ++++++++++++++ crates/rustc_codegen_nvvm/src/mono_item.rs | 3 +++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/crates/rustc_codegen_nvvm/src/context.rs b/crates/rustc_codegen_nvvm/src/context.rs index 253bb457..d6487631 100644 --- a/crates/rustc_codegen_nvvm/src/context.rs +++ b/crates/rustc_codegen_nvvm/src/context.rs @@ -317,8 +317,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { } } - /// Declare a function. All functions use the default ABI, NVVM ignores any calling convention markers. - /// All functions calls are generated according to the PTX calling convention. + /// Declare a function with appropriate PTX calling conventions. /// pub fn declare_fn( &self, @@ -332,8 +331,12 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { trace!("Declaring function `{}` with ty `{:?}`", name, ty); - // TODO(RDambrosio016): we should probably still generate accurate calling conv for functions - // just to make it easier to debug IR and/or make it more compatible with compiling using llvm + // Set PTX device calling convention for all functions declared here. 
+ // Kernel functions will have their calling convention overridden in mono_item.rs + unsafe { + llvm::LLVMSetFunctionCallConv(llfn, llvm::PtxCallConv::Device as u32); + } + llvm::SetUnnamedAddress(llfn, llvm::UnnamedAddr::Global); if let Some(abi) = fn_abi { abi.apply_attrs_llfn(self, llfn); diff --git a/crates/rustc_codegen_nvvm/src/llvm.rs b/crates/rustc_codegen_nvvm/src/llvm.rs index a0243eed..3eb75cd5 100644 --- a/crates/rustc_codegen_nvvm/src/llvm.rs +++ b/crates/rustc_codegen_nvvm/src/llvm.rs @@ -206,6 +206,20 @@ pub(crate) enum Visibility { Protected = 2, } +/// PTX/NVPTX calling conventions from LLVM +/// See: `PTX_Kernel` (71) and `PTX_Device` (72) in LLVM's `llvm/IR/CallingConv.h` +/// +/// While NVVM doesn't strictly require these calling conventions to be set +/// (it generates PTX according to its own rules), we set them anyway to +/// make the generated LLVM IR more accurate and easier to debug. +#[repr(u32)] +pub(crate) enum PtxCallConv { + /// PTX kernel calling convention + Kernel = 71, + /// PTX device calling convention + Device = 72, +} + /// LLVMUnnamedAddr #[repr(C)] pub(crate) enum UnnamedAddr { diff --git a/crates/rustc_codegen_nvvm/src/mono_item.rs b/crates/rustc_codegen_nvvm/src/mono_item.rs index 1145668f..95ccafca 100644 --- a/crates/rustc_codegen_nvvm/src/mono_item.rs +++ b/crates/rustc_codegen_nvvm/src/mono_item.rs @@ -98,6 +98,9 @@ impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'_, 'tcx> { // to nvvm.annotations per the nvvm ir docs. if nvvm_attrs.kernel { trace!("Marking function `{:?}` as a kernel", symbol_name); + llvm::LLVMSetFunctionCallConv(lldecl, llvm::PtxCallConv::Kernel as u32); + + // Add kernel metadata for NVVM let kernel = llvm::LLVMMDStringInContext(self.llcx, "kernel".as_ptr().cast(), 6); let mdvals = &[lldecl, kernel, self.const_i32(1)]; let node =