Skip to content

Commit e97db1e

Browse files
committed
Feat: WIP module rework (cubins dont work lol)
1 parent 8d264ad commit e97db1e

File tree

7 files changed

+213
-2
lines changed

7 files changed

+213
-2
lines changed

crates/cuda_std/src/atomic/intrinsics.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pub unsafe fn fence_acqrel_system() {
4949
asm!("fence.acq_rel.sys;");
5050
}
5151

52+
#[allow(unused_macros)]
5253
macro_rules! load_scope {
5354
(volatile, $scope:ident) => {
5455
""

crates/cust/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ cust_derive = { path = "../cust_derive", version = "0.1" }
1919
num-complex = { version = "0.4", optional = true }
2020
vek = { version = "0.15.1", optional = true, default-features = false }
2121
bytemuck = { version = "1.7.3", optional = true }
22+
goblin = { version = "0.4.3", default-features = false, features = ["elf32", "elf64", "std", "endian_fd"] }
2223

2324
[features]
2425
default = ["bytemuck"]

crates/cust/resources/add.cubin

2.19 KB
Binary file not shown.

crates/cust/resources/add.fatbin

2.92 KB
Binary file not shown.

crates/cust/src/module.rs

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::ffi::{c_void, CStr, CString};
88
use std::fmt;
99
use std::marker::PhantomData;
1010
use std::mem;
11+
use std::os::raw::c_uint;
1112
use std::path::Path;
1213
use std::ptr;
1314

@@ -17,6 +18,114 @@ pub struct Module {
1718
inner: cuda::CUmodule,
1819
}
1920

21+
/// The possible optimization levels when JIT compiling a PTX module.
///
/// `O4` is the most optimized level and is what the JIT compiler uses by default.
/// The discriminant of each variant is the numeric level itself (`O0` is `0`, `O4` is `4`).
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptLevel {
    /// Optimization level 0 (least optimized).
    O0 = 0,
    /// Optimization level 1.
    O1 = 1,
    /// Optimization level 2.
    O2 = 2,
    /// Optimization level 3.
    O3 = 3,
    /// Optimization level 4 (most optimized).
    O4 = 4,
}
31+
32+
/// The possible targets when JIT compiling a PTX module.
///
/// Variant names and discriminants spell the target version without the dot,
/// e.g. `Compute75 = 75`. The enum is `#[non_exhaustive]` so that new targets
/// can be added without breaking downstream `match` expressions.
#[non_exhaustive]
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitTarget {
    Compute20 = 20,
    Compute21 = 21,
    Compute30 = 30,
    Compute32 = 32,
    Compute35 = 35,
    Compute37 = 37,
    Compute50 = 50,
    Compute52 = 52,
    Compute53 = 53,
    Compute60 = 60,
    Compute61 = 61,
    Compute62 = 62,
    Compute70 = 70,
    Compute72 = 72,
    Compute75 = 75,
    Compute80 = 80,
    Compute86 = 86,
}
55+
56+
/// How to handle cases where a loaded module's data does not contain an exact match for the
/// specified architecture.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitFallback {
    /// Prefer to compile PTX, if present, when an exact binary match is not found.
    PreferPtx = 0,
    /// Prefer to fall back to a compatible binary code match when an exact match is not found.
    /// This means the driver may pick binary code for `7.0` if your device is `7.2`, for example.
    PreferCompatibleBinary = 1,
}
67+
68+
/// Different options that could be applied when loading a module.
69+
#[non_exhaustive]
70+
#[derive(Debug, Clone, Copy, PartialEq)]
71+
pub enum ModuleJitOption {
72+
/// Specifies the maximum amount of registers any compiled PTX is allowed to use.
73+
MaxRegisters(u32),
74+
/// Specifies the optimization level for the JIT compiler.
75+
OptLevel(OptLevel),
76+
/// Determines the PTX target from the current context's architecture. Cannot be combined with
77+
/// [`ModuleJitOption::Target`].
78+
DetermineTargetFromContext,
79+
/// Specifies the target for the JIT compiler. Cannot be combined with [`ModuleJitOption::DetermineTargetFromContext`].
80+
Target(JitTarget),
81+
/// Specifies how to handle cases where a loaded module's data does not have an exact match for the specified
82+
/// architecture.
83+
Fallback(JitFallback),
84+
/// Generates debug info in the compiled binary.
85+
GenenerateDebugInfo(bool),
86+
/// Generates line info in the compiled binary.
87+
GenerateLineInfo(bool),
88+
}
89+
90+
impl ModuleJitOption {
91+
pub fn into_raw(opts: &[Self]) -> (Vec<cuda::CUjit_option>, Vec<*mut c_void>) {
92+
let mut raw_opts = Vec::with_capacity(opts.len());
93+
let mut raw_vals = Vec::with_capacity(opts.len());
94+
for opt in opts {
95+
match opt {
96+
Self::MaxRegisters(regs) => {
97+
raw_opts.push(cuda::CUjit_option::CU_JIT_MAX_REGISTERS);
98+
raw_vals.push(regs as *const u32 as *mut _);
99+
}
100+
Self::OptLevel(level) => {
101+
raw_opts.push(cuda::CUjit_option::CU_JIT_OPTIMIZATION_LEVEL);
102+
raw_vals.push(level as *const OptLevel as *mut _);
103+
}
104+
Self::DetermineTargetFromContext => {
105+
raw_opts.push(cuda::CUjit_option::CU_JIT_TARGET_FROM_CUCONTEXT);
106+
}
107+
Self::Target(target) => {
108+
raw_opts.push(cuda::CUjit_option::CU_JIT_TARGET);
109+
raw_vals.push(target as *const JitTarget as *mut _);
110+
}
111+
Self::Fallback(fallback) => {
112+
raw_opts.push(cuda::CUjit_option::CU_JIT_FALLBACK_STRATEGY);
113+
raw_vals.push(fallback as *const JitFallback as *mut _);
114+
}
115+
Self::GenenerateDebugInfo(gen) => {
116+
raw_opts.push(cuda::CUjit_option::CU_JIT_GENERATE_DEBUG_INFO);
117+
raw_vals.push(gen as *const bool as *mut _);
118+
}
119+
Self::GenerateLineInfo(gen) => {
120+
raw_opts.push(cuda::CUjit_option::CU_JIT_GENERATE_LINE_INFO);
121+
raw_vals.push(gen as *const bool as *mut _)
122+
}
123+
}
124+
}
125+
(raw_opts, raw_vals)
126+
}
127+
}
128+
20129
#[cfg(unix)]
21130
fn path_to_bytes<P: AsRef<Path>>(path: P) -> Vec<u8> {
22131
use std::os::unix::ffi::OsStrExt;
@@ -66,12 +175,106 @@ impl Module {
66175
}
67176
}
68177

178+
/// Creates a new module by loading a fatbin (fat binary) file.
179+
///
180+
/// Fatbinary files are files that contain multiple ptx or cubin files. The driver will choose already-built
181+
/// cubin if it is present, and otherwise JIT compile any PTX in the file to cubin.
182+
///
183+
/// # Example
184+
///
185+
/// ```
186+
/// # use cust::*;
187+
/// # use std::error::Error;
188+
/// # fn main() -> Result<(), Box<dyn Error>> {
189+
/// # let _ctx = quick_init()?;
190+
/// use cust::module::Module;
191+
/// let fatbin_bytes = std::fs::read("./resources/add.cubin")?;
192+
/// assert!(fatbin_bytes.contains(&0));
193+
/// let module = Module::from_cubin(&fatbin_bytes, &[])?;
194+
/// # Ok(())
195+
/// # }
196+
/// ```
197+
pub fn from_fatbin<T: AsRef<[u8]>>(
198+
bytes: T,
199+
options: &[ModuleJitOption],
200+
) -> CudaResult<Module> {
201+
let mut bytes = bytes.as_ref().to_vec();
202+
bytes.push(0);
203+
// fatbins are just ELF files like cubins, and cuModuleLoadDataEx accepts ptx, cubin, and fatbin.
204+
// We just make the distinction in case we want to do anything extra in the future. As well
205+
// as keep things explicit to anyone reading the code.
206+
Self::from_cubin(bytes, options)
207+
}
208+
209+
pub unsafe fn from_fatbin_unchecked<T: AsRef<[u8]>>(
210+
bytes: T,
211+
options: &[ModuleJitOption],
212+
) -> CudaResult<Module> {
213+
Self::from_cubin_unchecked(bytes, options)
214+
}
215+
216+
pub fn from_cubin<T: AsRef<[u8]>>(bytes: T, options: &[ModuleJitOption]) -> CudaResult<Module> {
217+
let bytes = bytes.as_ref();
218+
goblin::elf::Elf::parse(bytes).expect("Cubin/Fatbin was not valid ELF!");
219+
// SAFETY: we verified the bytes were valid ELF
220+
unsafe { Self::from_cubin_unchecked(bytes, options) }
221+
}
222+
223+
pub unsafe fn from_cubin_unchecked<T: AsRef<[u8]>>(
224+
bytes: T,
225+
options: &[ModuleJitOption],
226+
) -> CudaResult<Module> {
227+
let bytes = bytes.as_ref();
228+
let mut module = Module {
229+
inner: ptr::null_mut(),
230+
};
231+
let (mut options, mut option_values) = ModuleJitOption::into_raw(options);
232+
cuda::cuModuleLoadDataEx(
233+
&mut module.inner as *mut cuda::CUmodule,
234+
bytes.as_ptr() as *const c_void,
235+
options.len() as c_uint,
236+
options.as_mut_ptr(),
237+
option_values.as_mut_ptr(),
238+
)
239+
.to_result()?;
240+
Ok(module)
241+
}
242+
243+
pub fn from_ptx_cstr(cstr: &CStr, options: &[ModuleJitOption]) -> CudaResult<Module> {
244+
unsafe {
245+
let mut module = Module {
246+
inner: ptr::null_mut(),
247+
};
248+
let (mut options, mut option_values) = ModuleJitOption::into_raw(options);
249+
cuda::cuModuleLoadDataEx(
250+
&mut module.inner as *mut cuda::CUmodule,
251+
cstr.as_ptr() as *const c_void,
252+
options.len() as c_uint,
253+
options.as_mut_ptr(),
254+
option_values.as_mut_ptr(),
255+
)
256+
.to_result()?;
257+
Ok(module)
258+
}
259+
}
260+
261+
pub fn from_ptx<T: AsRef<str>>(string: T, options: &[ModuleJitOption]) -> CudaResult<Module> {
262+
let cstr = CString::new(string.as_ref())
263+
.expect("string given to Module::from_str contained nul bytes");
264+
Self::from_ptx_cstr(cstr.as_c_str(), options)
265+
}
266+
69267
/// Load a module from a normal (rust) string, implicitly making it into
70268
/// a cstring.
269+
#[deprecated(
270+
since = "0.3.0",
271+
note = "from_str was too generic of a name, use from_ptx instead, passing an empty slice of options (usually)"
272+
)]
71273
#[allow(clippy::should_implement_trait)]
72274
pub fn from_str<T: AsRef<str>>(string: T) -> CudaResult<Module> {
73275
let cstr = CString::new(string.as_ref())
74276
.expect("string given to Module::from_str contained nul bytes");
277+
#[allow(deprecated)]
75278
Self::load_from_string(cstr.as_c_str())
76279
}
77280

@@ -98,6 +301,12 @@ impl Module {
98301
/// # Ok(())
99302
/// # }
100303
/// ```
304+
#[deprecated(
305+
since = "0.3.0",
306+
note = "load_from_string was an inconsistent name with inconsistent params, use from_ptx/from_ptx_cstr, passing
307+
an empty slice of options (usually)
308+
"
309+
)]
101310
pub fn load_from_string(image: &CStr) -> CudaResult<Module> {
102311
unsafe {
103312
let mut module = Module {

examples/cuda/cpu/add/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ fn main() -> Result<(), Box<dyn Error>> {
2222

2323
// Make the CUDA module, modules just house the GPU code for the kernels we created.
2424
// they can be made from PTX code, cubins, or fatbins.
25-
let module = Module::from_str(PTX)?;
25+
let module = Module::from_ptx(PTX, &[])?;
2626

2727
// make a CUDA stream to issue calls to. You can think of this as an OS thread but for dispatching
2828
// GPU calls.

examples/cuda/cpu/path_tracer/src/cuda/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ impl CudaRenderer {
4747

4848
let optix_context = OptixContext::new(&context).unwrap();
4949

50-
let module = Module::from_str(PTX)?;
50+
let module = Module::from_ptx(PTX, &[])?;
5151
let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
5252
let mut denoiser =
5353
Denoiser::new(&optix_context, DenoiserModelKind::Ldr, Default::default()).unwrap();

0 commit comments

Comments
 (0)