@@ -8,6 +8,7 @@ use std::ffi::{c_void, CStr, CString};
8
8
use std:: fmt;
9
9
use std:: marker:: PhantomData ;
10
10
use std:: mem;
11
+ use std:: os:: raw:: c_uint;
11
12
use std:: path:: Path ;
12
13
use std:: ptr;
13
14
@@ -17,6 +18,114 @@ pub struct Module {
17
18
inner : cuda:: CUmodule ,
18
19
}
19
20
21
/// The possible optimization levels when JIT compiling a PTX module. `O4` by default (most optimized).
///
/// The enum is `#[repr(u32)]` and every variant's discriminant is the level number in its
/// name, so `OptLevel::O3 as u32 == 3`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptLevel {
    /// Optimization level 0.
    O0 = 0,
    /// Optimization level 1.
    O1 = 1,
    /// Optimization level 2.
    O2 = 2,
    /// Optimization level 3.
    O3 = 3,
    /// Optimization level 4.
    O4 = 4,
}
31
+
32
/// The possible targets when JIT compiling a PTX module.
///
/// The enum is `#[repr(u32)]` and every variant's discriminant is the two-digit number in its
/// name (for example `JitTarget::Compute75 as u32 == 75`). It is `#[non_exhaustive]`, so
/// downstream `match`es need a wildcard arm.
#[non_exhaustive]
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitTarget {
    Compute20 = 20,
    Compute21 = 21,
    Compute30 = 30,
    Compute32 = 32,
    Compute35 = 35,
    Compute37 = 37,
    Compute50 = 50,
    Compute52 = 52,
    Compute53 = 53,
    Compute60 = 60,
    Compute61 = 61,
    Compute62 = 62,
    Compute70 = 70,
    Compute72 = 72,
    Compute75 = 75,
    Compute80 = 80,
    Compute86 = 86,
}
55
+
56
/// How to handle cases where a loaded module's data does not contain an exact match for the
/// specified architecture.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitFallback {
    /// Prefer to compile PTX, if present, when an exact binary match is not found.
    PreferPtx = 0,
    /// Prefer to fall back to a compatible binary code match if an exact match is not found.
    /// This means the driver may pick binary code for `7.0` if your device is `7.2` for example.
    PreferCompatibleBinary = 1,
}
67
+
68
+ /// Different options that could be applied when loading a module.
69
+ #[ non_exhaustive]
70
+ #[ derive( Debug , Clone , Copy , PartialEq ) ]
71
+ pub enum ModuleJitOption {
72
+ /// Specifies the maximum amount of registers any compiled PTX is allowed to use.
73
+ MaxRegisters ( u32 ) ,
74
+ /// Specifies the optimization level for the JIT compiler.
75
+ OptLevel ( OptLevel ) ,
76
+ /// Determines the PTX target from the current context's architecture. Cannot be combined with
77
+ /// [`ModuleJitOption::Target`].
78
+ DetermineTargetFromContext ,
79
+ /// Specifies the target for the JIT compiler. Cannot be combined with [`ModuleJitOption::DetermineTargetFromContext`].
80
+ Target ( JitTarget ) ,
81
+ /// Specifies how to handle cases where a loaded module's data does not have an exact match for the specified
82
+ /// architecture.
83
+ Fallback ( JitFallback ) ,
84
+ /// Generates debug info in the compiled binary.
85
+ GenenerateDebugInfo ( bool ) ,
86
+ /// Generates line info in the compiled binary.
87
+ GenerateLineInfo ( bool ) ,
88
+ }
89
+
90
+ impl ModuleJitOption {
91
+ pub fn into_raw ( opts : & [ Self ] ) -> ( Vec < cuda:: CUjit_option > , Vec < * mut c_void > ) {
92
+ let mut raw_opts = Vec :: with_capacity ( opts. len ( ) ) ;
93
+ let mut raw_vals = Vec :: with_capacity ( opts. len ( ) ) ;
94
+ for opt in opts {
95
+ match opt {
96
+ Self :: MaxRegisters ( regs) => {
97
+ raw_opts. push ( cuda:: CUjit_option :: CU_JIT_MAX_REGISTERS ) ;
98
+ raw_vals. push ( regs as * const u32 as * mut _ ) ;
99
+ }
100
+ Self :: OptLevel ( level) => {
101
+ raw_opts. push ( cuda:: CUjit_option :: CU_JIT_OPTIMIZATION_LEVEL ) ;
102
+ raw_vals. push ( level as * const OptLevel as * mut _ ) ;
103
+ }
104
+ Self :: DetermineTargetFromContext => {
105
+ raw_opts. push ( cuda:: CUjit_option :: CU_JIT_TARGET_FROM_CUCONTEXT ) ;
106
+ }
107
+ Self :: Target ( target) => {
108
+ raw_opts. push ( cuda:: CUjit_option :: CU_JIT_TARGET ) ;
109
+ raw_vals. push ( target as * const JitTarget as * mut _ ) ;
110
+ }
111
+ Self :: Fallback ( fallback) => {
112
+ raw_opts. push ( cuda:: CUjit_option :: CU_JIT_FALLBACK_STRATEGY ) ;
113
+ raw_vals. push ( fallback as * const JitFallback as * mut _ ) ;
114
+ }
115
+ Self :: GenenerateDebugInfo ( gen) => {
116
+ raw_opts. push ( cuda:: CUjit_option :: CU_JIT_GENERATE_DEBUG_INFO ) ;
117
+ raw_vals. push ( gen as * const bool as * mut _ ) ;
118
+ }
119
+ Self :: GenerateLineInfo ( gen) => {
120
+ raw_opts. push ( cuda:: CUjit_option :: CU_JIT_GENERATE_LINE_INFO ) ;
121
+ raw_vals. push ( gen as * const bool as * mut _ )
122
+ }
123
+ }
124
+ }
125
+ ( raw_opts, raw_vals)
126
+ }
127
+ }
128
+
20
129
#[ cfg( unix) ]
21
130
fn path_to_bytes < P : AsRef < Path > > ( path : P ) -> Vec < u8 > {
22
131
use std:: os:: unix:: ffi:: OsStrExt ;
@@ -66,12 +175,106 @@ impl Module {
66
175
}
67
176
}
68
177
178
+ /// Creates a new module by loading a fatbin (fat binary) file.
179
+ ///
180
+ /// Fatbinary files are files that contain multiple ptx or cubin files. The driver will choose already-built
181
+ /// cubin if it is present, and otherwise JIT compile any PTX in the file to cubin.
182
+ ///
183
+ /// # Example
184
+ ///
185
+ /// ```
186
+ /// # use cust::*;
187
+ /// # use std::error::Error;
188
+ /// # fn main() -> Result<(), Box<dyn Error>> {
189
+ /// # let _ctx = quick_init()?;
190
+ /// use cust::module::Module;
191
+ /// let fatbin_bytes = std::fs::read("./resources/add.cubin")?;
192
+ /// assert!(fatbin_bytes.contains(&0));
193
+ /// let module = Module::from_cubin(&fatbin_bytes, &[])?;
194
+ /// # Ok(())
195
+ /// # }
196
+ /// ```
197
+ pub fn from_fatbin < T : AsRef < [ u8 ] > > (
198
+ bytes : T ,
199
+ options : & [ ModuleJitOption ] ,
200
+ ) -> CudaResult < Module > {
201
+ let mut bytes = bytes. as_ref ( ) . to_vec ( ) ;
202
+ bytes. push ( 0 ) ;
203
+ // fatbins are just ELF files like cubins, and cuModuleLoadDataEx accepts ptx, cubin, and fatbin.
204
+ // We just make the distinction in case we want to do anything extra in the future. As well
205
+ // as keep things explicit to anyone reading the code.
206
+ Self :: from_cubin ( bytes, options)
207
+ }
208
+
209
+ pub unsafe fn from_fatbin_unchecked < T : AsRef < [ u8 ] > > (
210
+ bytes : T ,
211
+ options : & [ ModuleJitOption ] ,
212
+ ) -> CudaResult < Module > {
213
+ Self :: from_cubin_unchecked ( bytes, options)
214
+ }
215
+
216
+ pub fn from_cubin < T : AsRef < [ u8 ] > > ( bytes : T , options : & [ ModuleJitOption ] ) -> CudaResult < Module > {
217
+ let bytes = bytes. as_ref ( ) ;
218
+ goblin:: elf:: Elf :: parse ( bytes) . expect ( "Cubin/Fatbin was not valid ELF!" ) ;
219
+ // SAFETY: we verified the bytes were valid ELF
220
+ unsafe { Self :: from_cubin_unchecked ( bytes, options) }
221
+ }
222
+
223
+ pub unsafe fn from_cubin_unchecked < T : AsRef < [ u8 ] > > (
224
+ bytes : T ,
225
+ options : & [ ModuleJitOption ] ,
226
+ ) -> CudaResult < Module > {
227
+ let bytes = bytes. as_ref ( ) ;
228
+ let mut module = Module {
229
+ inner : ptr:: null_mut ( ) ,
230
+ } ;
231
+ let ( mut options, mut option_values) = ModuleJitOption :: into_raw ( options) ;
232
+ cuda:: cuModuleLoadDataEx (
233
+ & mut module. inner as * mut cuda:: CUmodule ,
234
+ bytes. as_ptr ( ) as * const c_void ,
235
+ options. len ( ) as c_uint ,
236
+ options. as_mut_ptr ( ) ,
237
+ option_values. as_mut_ptr ( ) ,
238
+ )
239
+ . to_result ( ) ?;
240
+ Ok ( module)
241
+ }
242
+
243
+ pub fn from_ptx_cstr ( cstr : & CStr , options : & [ ModuleJitOption ] ) -> CudaResult < Module > {
244
+ unsafe {
245
+ let mut module = Module {
246
+ inner : ptr:: null_mut ( ) ,
247
+ } ;
248
+ let ( mut options, mut option_values) = ModuleJitOption :: into_raw ( options) ;
249
+ cuda:: cuModuleLoadDataEx (
250
+ & mut module. inner as * mut cuda:: CUmodule ,
251
+ cstr. as_ptr ( ) as * const c_void ,
252
+ options. len ( ) as c_uint ,
253
+ options. as_mut_ptr ( ) ,
254
+ option_values. as_mut_ptr ( ) ,
255
+ )
256
+ . to_result ( ) ?;
257
+ Ok ( module)
258
+ }
259
+ }
260
+
261
+ pub fn from_ptx < T : AsRef < str > > ( string : T , options : & [ ModuleJitOption ] ) -> CudaResult < Module > {
262
+ let cstr = CString :: new ( string. as_ref ( ) )
263
+ . expect ( "string given to Module::from_str contained nul bytes" ) ;
264
+ Self :: from_ptx_cstr ( cstr. as_c_str ( ) , options)
265
+ }
266
+
69
267
/// Load a module from a normal (rust) string, implicitly making it into
70
268
/// a cstring.
269
+ #[ deprecated(
270
+ since = "0.3.0" ,
271
+ note = "from_str was too generic of a name, use from_ptx instead, passing an empty slice of options (usually)"
272
+ ) ]
71
273
#[ allow( clippy:: should_implement_trait) ]
72
274
pub fn from_str < T : AsRef < str > > ( string : T ) -> CudaResult < Module > {
73
275
let cstr = CString :: new ( string. as_ref ( ) )
74
276
. expect ( "string given to Module::from_str contained nul bytes" ) ;
277
+ #[ allow( deprecated) ]
75
278
Self :: load_from_string ( cstr. as_c_str ( ) )
76
279
}
77
280
@@ -98,6 +301,12 @@ impl Module {
98
301
/// # Ok(())
99
302
/// # }
100
303
/// ```
304
+ #[ deprecated(
305
+ since = "0.3.0" ,
306
+ note = "load_from_string was an inconsistent name with inconsistent params, use from_ptx/from_ptx_cstr, passing
307
+ an empty slice of options (usually)
308
+ "
309
+ ) ]
101
310
pub fn load_from_string ( image : & CStr ) -> CudaResult < Module > {
102
311
unsafe {
103
312
let mut module = Module {
0 commit comments