Skip to content

Commit 39edde4

Browse files
committed
Feat: add libm overriding
1 parent 2a5eb69 commit 39edde4

File tree

11 files changed

+157
-24
lines changed

11 files changed

+157
-24
lines changed

crates/cuda_builder/src/lib.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@ pub struct CudaBuilder {
107107
///
108108
/// `false` by default.
109109
pub optix: bool,
110+
/// Whether to override calls to [`libm`](https://docs.rs/libm/latest/libm/) with calls to libdevice intrinsics.
111+
///
112+
/// Libm is used by no_std crates for functions such as sin, cos, fabs, etc. However, CUDA provides
113+
/// extremely fast GPU-specific implementations of such functions through `libdevice`. Therefore, the codegen
114+
/// exposes the option to automatically override any calls to libm functions with calls to libdevice functions.
115+
/// However, this means the overridden functions are likely not to be deterministic, so if you rely on strict
116+
/// determinism in things like `rapier`, then it may be helpful to disable such a feature.
117+
///
118+
/// `true` by default.
119+
pub override_libm: bool,
110120
}
111121

112122
impl CudaBuilder {
@@ -125,6 +135,7 @@ impl CudaBuilder {
125135
fma_contraction: true,
126136
emit: None,
127137
optix: false,
138+
override_libm: true,
128139
}
129140
}
130141

@@ -233,6 +244,18 @@ impl CudaBuilder {
233244
self
234245
}
235246

247+
/// Whether to override calls to [`libm`](https://docs.rs/libm/latest/libm/) with calls to libdevice intrinsics.
248+
///
249+
/// Libm is used by no_std crates for functions such as sin, cos, fabs, etc. However, CUDA provides
250+
/// extremely fast GPU-specific implementations of such functions through `libdevice`. Therefore, the codegen
251+
/// exposes the option to automatically override any calls to libm functions with calls to libdevice functions.
252+
/// However, this means the overridden functions are likely not to be deterministic, so if you rely on strict
253+
/// determinism in things like `rapier`, then it may be helpful to disable such a feature.
254+
pub fn override_libm(mut self, override_libm: bool) -> Self {
255+
self.override_libm = override_libm;
256+
self
257+
}
258+
236259
/// Runs rustc to build the codegen and codegens the gpu crate, returning the path of the final
237260
/// ptx file. If [`ptx_file_copy_path`](Self::ptx_file_copy_path) is set, this returns the copied path.
238261
pub fn build(self) -> Result<PathBuf, CudaBuilderError> {
@@ -351,6 +374,10 @@ fn invoke_rustc(builder: &CudaBuilder) -> Result<PathBuf, CudaBuilderError> {
351374
llvm_args.push("-fma=0".to_string());
352375
}
353376

377+
if builder.override_libm {
378+
llvm_args.push("--override-libm".to_string());
379+
}
380+
354381
let llvm_args = llvm_args.join(" ");
355382
if !llvm_args.is_empty() {
356383
rustflags.push(["-Cllvm-args=", &llvm_args].concat());

crates/cust/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ at some point, but now we reconsidered that it may be the wrong choice.
1414
- Remove `GpuBox::as_device_ptr_mut` and `GpuBuffer::as_device_ptr_mut`.
1515
- Change `GpuBox::as_device_ptr` and `GpuBuffer::as_device_ptr` to take `&self` instead of `&mut self`.
1616
- Remove accidentally added `vek` default feature.
17+
- `vek` feature now uses `default-features = false`; this also means `Rgb` and `Rgba` no longer implement `DeviceCopy`.

crates/cust/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ description = "High level bindings to the CUDA Driver API"
1111
cust_raw = { path = "../cust_raw", version = "0.11.2"}
1212
bitflags = "1.2"
1313
cust_derive = { path = "../cust_derive", version = "0.1" }
14-
vek = { version = "0.15.1", optional = true }
14+
vek = { version = "0.15.1", optional = true, default-features = false }
1515

1616
[build-dependencies]
1717
find_cuda_helper = { path = "../find_cuda_helper", version = "0.1" }

crates/cust/src/memory/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ use vek::*;
281281

282282
#[cfg(feature = "vek")]
283283
impl_device_copy_vek! {
284-
Vec2, Vec3, Vec4, Extent2, Extent3, Rgb, Rgba,
284+
Vec2, Vec3, Vec4, Extent2, Extent3,
285285
Mat2, Mat3, Mat4,
286286
CubicBezier2, CubicBezier3,
287287
Quaternion,

crates/rustc_codegen_nvvm/CHANGELOG.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,22 @@ method of codegen for the final steps of the codegen. We no longer lazily-load m
1111
we instead merge all the modules into one then run global DCE on it before giving it to libnvvm.
1212

1313
This means all of the dead code is gone before it gets to the libnvvm stage, drastically lowering the size of
14-
the built PTX and improving codegen performance.
14+
the built PTX and improving codegen performance. `cuda_std` also has a macro `#[externally_visible]` which can
15+
be used if you want to keep a function around for things like linking multiple PTX files together.
16+
17+
### Libm override
18+
19+
The codegen now has the ability to override [`libm`](https://docs.rs/libm/latest/libm/) functions with
20+
[`libdevice`](https://docs.nvidia.com/cuda/libdevice-users-guide/introduction.html#introduction) intrinsics.
21+
22+
Libdevice is a bitcode library shipped with every CUDA SDK installation which provides float routines that
23+
are optimized for the GPU and for specific GPU architectures. However, these routines are hard to use automatically because
24+
no_std math crates typically use libm for floating-point math. So users often ended up with needlessly slow or large PTX files
25+
because they used "emulated" routines.
26+
27+
Now, by default (can be disabled in cuda_builder) the codegen will override libm functions with calls to libdevice automatically.
28+
However, if you rely on libm for determinism, you must disable the overriding, since libdevice is not strictly deterministic.
29+
This also generally makes the PTX much smaller: in our example path tracer, it slimmed the PTX file from about `3800` LoC to `2300` LoC.
1530

1631
- Trace-level debug is compiled out for release now, decreasing the size of the codegen dll and improving compile times.
1732

crates/rustc_codegen_nvvm/src/back.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::llvm::{self};
2+
use crate::override_fns::define_or_override_fn;
23
use crate::{builder::Builder, context::CodegenCx, lto::ThinBuffer, LlvmMod, NvvmCodegenBackend};
34
use libc::{c_char, size_t};
45
use rustc_codegen_ssa::back::write::{TargetMachineFactoryConfig, TargetMachineFactoryFn};
@@ -14,6 +15,7 @@ use rustc_data_structures::small_c_str::SmallCStr;
1415
use rustc_errors::{FatalError, Handler};
1516
use rustc_fs_util::path_to_c_string;
1617
use rustc_middle::bug;
18+
use rustc_middle::mir::mono::MonoItem;
1719
use rustc_middle::{dep_graph, ty::TyCtxt};
1820
use rustc_session::config::{self, DebugInfo, OutputType};
1921
use rustc_session::Session;
@@ -263,7 +265,11 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen
263265

264266
// ... and now that we have everything pre-defined, fill out those definitions.
265267
for &(mono_item, _) in &mono_items {
266-
mono_item.define::<Builder<'_, '_, '_>>(&cx);
268+
if let MonoItem::Fn(func) = mono_item {
269+
define_or_override_fn(func, &cx);
270+
} else {
271+
mono_item.define::<Builder<'_, '_, '_>>(&cx);
272+
}
267273
}
268274

269275
// a main function for gpu kernels really makes no sense but

crates/rustc_codegen_nvvm/src/context.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,7 @@ pub(crate) struct CodegenCx<'ll, 'tcx> {
9292
eh_personality: &'ll Value,
9393

9494
pub symbols: Symbols,
95-
96-
// we do not currently use codegen_args before linking, and during linking we reparse
97-
// them because codegencx is not available at link time. However, we keep this so
98-
// it is easier to use them in the future and add args we want to use before linking.
99-
#[allow(dead_code)]
10095
pub codegen_args: CodegenArgs,
101-
10296
// the value of the last call instruction. Needed for return type remapping.
10397
pub last_call_llfn: Cell<Option<&'ll Value>>,
10498
}
@@ -505,25 +499,31 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
505499
}
506500
}
507501

502+
#[derive(Default, Clone)]
508503
pub struct CodegenArgs {
509504
pub nvvm_options: Vec<NvvmOption>,
505+
pub override_libm: bool,
510506
}
511507

512508
impl CodegenArgs {
513509
pub fn from_session(sess: &Session) -> Self {
514-
match Self::parse(&sess.opts.cg.llvm_args) {
515-
Ok(x) => x,
516-
Err(err) => sess.fatal(&format!("Failed to parse codegen args: {}", err)),
517-
}
510+
Self::parse(&sess.opts.cg.llvm_args)
518511
}
519512

520513
// we may want to use rustc's own option parsing facilities to have better errors in the future.
521-
pub fn parse(args: &[String]) -> Result<Self, &'static str> {
522-
let nvvm_options = args
523-
.iter()
524-
.map(|x| NvvmOption::from_str(x))
525-
.collect::<Result<Vec<_>, _>>()?;
526-
Ok(Self { nvvm_options })
514+
pub fn parse(args: &[String]) -> Self {
515+
// TODO: replace this with a "proper" arg parser.
516+
let mut cg_args = Self::default();
517+
518+
for arg in args {
519+
if let Ok(flag) = NvvmOption::from_str(arg) {
520+
cg_args.nvvm_options.push(flag);
521+
} else if arg == "--override-libm" {
522+
cg_args.override_libm = true;
523+
}
524+
}
525+
526+
cg_args
527527
}
528528
}
529529

crates/rustc_codegen_nvvm/src/ctx_intrinsics.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
262262
"__nv_tgamma" |
263263
"__nv_trunc" |
264264
"__nv_y0" |
265-
"__nv_y1" |
266-
"__nv_yn",
265+
"__nv_y1",
267266
fn(t_f64) -> t_f64
268267
);
269268

@@ -316,8 +315,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
316315
"__nv_tgammaf" |
317316
"__nv_truncf" |
318317
"__nv_y0f" |
319-
"__nv_y1f" |
320-
"__nv_ynf",
318+
"__nv_y1f",
321319
fn(t_f32) -> t_f32
322320
);
323321

@@ -408,5 +406,17 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
408406
"__nv_fmaf",
409407
fn(t_f32, t_f32, t_f32) -> t_f32
410408
);
409+
410+
ifn!(
411+
map,
412+
"__nv_yn",
413+
fn(t_i32, t_f64) -> t_f64
414+
);
415+
416+
ifn!(
417+
map,
418+
"__nv_ynf",
419+
fn(t_i32, t_f32) -> t_f32
420+
);
411421
}
412422
}

crates/rustc_codegen_nvvm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ mod llvm;
4444
mod lto;
4545
mod mono_item;
4646
mod nvvm;
47+
mod override_fns;
4748
mod target;
4849
mod ty;
4950

crates/rustc_codegen_nvvm/src/llvm.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,17 @@ pub(crate) fn get_param(llfn: &Value, index: c_uint) -> &Value {
119119
}
120120
}
121121

122+
/// Safe wrapper around `LLVMGetParams`.
123+
pub(crate) fn get_params(llfn: &Value) -> Vec<&Value> {
124+
unsafe {
125+
let count = LLVMCountParams(llfn) as usize;
126+
let mut params = Vec::with_capacity(count);
127+
LLVMGetParams(llfn, params.as_mut_ptr());
128+
params.set_len(count);
129+
params
130+
}
131+
}
132+
122133
/// Safe wrapper for `LLVMGetValueName2` into a byte slice
123134
pub(crate) fn get_value_name(value: &Value) -> &[u8] {
124135
unsafe {

0 commit comments

Comments
 (0)