Skip to content

Commit 45b334d

Browse files
committed
Feat: lay down foundation for cuBLAS work
1 parent 3920058 commit 45b334d

File tree

14 files changed

+6449
-0
lines changed

14 files changed

+6449
-0
lines changed

crates/blastoff/Cargo.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[package]
2+
name = "blastoff"
3+
version = "0.1.0"
4+
edition = "2021"
5+
authors = ["Riccardo D'Ambrosio <[email protected]>"]
6+
repository = "https://github.com/Rust-GPU/Rust-CUDA"
7+
8+
[dependencies]
9+
bitflags = "1.3.2"
10+
cublas_sys = { version = "0.1", path = "../cublas_sys" }
11+
cust = { version = "0.2", path = "../cust", features = ["num-complex"] }
12+
num-complex = "0.4.0"

crates/blastoff/src/context.rs

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
use crate::{error::*, sys};
2+
use cust::stream::Stream;
3+
use std::ffi::CString;
4+
use std::mem::{self, MaybeUninit};
5+
use std::os::raw::c_char;
6+
use std::ptr;
7+
8+
type Result<T, E = Error> = std::result::Result<T, E>;
9+
10+
bitflags::bitflags! {
    /// Configures precision levels for the math in cuBLAS.
    ///
    /// The raw bits are handed to cuBLAS unchanged by `CublasContext::set_math_mode`, so every
    /// constant must carry the numeric value of the matching `cublasMath_t` enumerator.
    #[derive(Default)]
    pub struct MathMode: u32 {
        /// Highest performance mode which uses compute and intermediate storage precisions
        /// with at least the same number of mantissa and exponent bits as requested. Will
        /// also use tensor cores when possible.
        const DEFAULT = 0;
        /// Mode which uses prescribed precision and standardized arithmetic for all phases of
        /// calculations and is primarily intended for numerical robustness studies, testing,
        /// and debugging. This mode might not be as performant as the other modes.
        // `CUBLAS_PEDANTIC_MATH` is 2; the value 1 belongs to the deprecated
        // `CUBLAS_TENSOR_OP_MATH`, so using 1 here would silently select the wrong mode.
        const PEDANTIC = 2;
        /// Enable acceleration of single precision routines using TF32 Tensor Cores.
        const TF32_TENSOR_OP = 3;
        /// Forces any reductions during matrix multiplication to use the accumulator type
        /// (i.e. the compute type) and not the output type in case of mixed precision routines
        /// where output precision is less than compute type precision.
        const DISALLOW_REDUCED_PRECISION_REDUCTION = 16;
    }
}
30+
31+
/// The central structure required to do anything with cuBLAS. It holds and manages internal memory allocations
32+
///
33+
/// # Multithreaded Usage
34+
///
35+
/// While it is technically allowed to use the same context across threads, it is very suboptimal and dangerous
36+
/// so we do not expose this functionality. Instead, you should create a context for every thread (as the cuBLAS docs reccomend).
37+
///
38+
/// # Multi-Device Usage
39+
///
40+
/// cuBLAS contexts are tied to the current device (through the current CUDA Context), therefore, for multi-device usage you should
41+
/// create a context for every device.
42+
///
43+
/// # Drop Cost
44+
///
45+
/// cuBLAS contexts hold internal memory allocations required by the library, and will free those allocations on drop. They will
46+
/// also synchronize the entire device when dropping the context. Therefore, you should minimize both the amount of contexts, and the
47+
/// amount of context drops. You should generally allocate all the contexts at once, and drop them all at once.
48+
#[derive(Debug)]
49+
pub struct CublasContext {
50+
pub(crate) raw: sys::v2::cublasHandle_t,
51+
}
52+
53+
impl CublasContext {
54+
/// Creates a new cuBLAS context, allocating all of the required host and device memory.
55+
///
56+
/// # Example
57+
///
58+
/// ```
59+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
60+
/// # let _ctx = blastoff::__doctest_setup();
61+
/// use blastoff::context::CublasContext;
62+
/// let ctx = CublasContext::new()?;
63+
/// # Ok(())
64+
/// # }
65+
/// ```
66+
pub fn new() -> Result<Self> {
67+
let mut raw = MaybeUninit::uninit();
68+
unsafe {
69+
sys::v2::cublasCreate_v2(raw.as_mut_ptr()).to_result()?;
70+
sys::v2::cublasSetPointerMode_v2(
71+
raw.assume_init(),
72+
sys::v2::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
73+
)
74+
.to_result()?;
75+
Ok(Self {
76+
raw: raw.assume_init(),
77+
})
78+
}
79+
}
80+
81+
/// Tries to destroy a [`CublasContext`], returning an error if it fails.
82+
pub fn drop(mut ctx: CublasContext) -> DropResult<CublasContext> {
83+
if ctx.raw.is_null() {
84+
return Ok(());
85+
}
86+
87+
unsafe {
88+
let inner = mem::replace(&mut ctx.raw, ptr::null_mut());
89+
match sys::v2::cublasDestroy_v2(inner).to_result() {
90+
Ok(()) => {
91+
mem::forget(ctx);
92+
Ok(())
93+
}
94+
Err(e) => Err((e, CublasContext { raw: inner })),
95+
}
96+
}
97+
}
98+
99+
/// Returns the major, minor, and patch versions of the cuBLAS library.
100+
pub fn version(&self) -> (u32, u32, u32) {
101+
let mut raw = MaybeUninit::<u32>::uninit();
102+
unsafe {
103+
// getVersion can't fail
104+
sys::v2::cublasGetVersion_v2(self.raw, raw.as_mut_ptr().cast())
105+
.to_result()
106+
.unwrap();
107+
108+
let raw = raw.assume_init();
109+
(raw / 1000, (raw % 1000) / 100, raw % 100)
110+
}
111+
}
112+
113+
/// Executes a given closure in a specific CUDA [`Stream`], specifically, it sets the current cublas stream
114+
/// for the context, runs the closure, then unsets the stream back to NULL.
115+
pub fn with_stream<T, F: FnOnce(&mut Self) -> Result<T>>(
116+
&mut self,
117+
stream: &Stream,
118+
func: F,
119+
) -> Result<T> {
120+
unsafe {
121+
// cudaStream_t is the same as CUstream
122+
sys::v2::cublasSetStream_v2(self.raw, mem::transmute(stream.as_inner())).to_result()?;
123+
let res = func(self)?;
124+
// reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
125+
// execute a raw sys function with the context's handle.
126+
sys::v2::cublasSetStream_v2(self.raw, ptr::null_mut()).to_result()?;
127+
Ok(res)
128+
}
129+
}
130+
131+
/// Sets whether the cuBLAS library is allowed to use atomics for certain routines such as `symv` or `hemv`.
132+
///
133+
/// cuBLAS has specialized versions of functions that use atomics to accumulate results, which is generally significantly
134+
/// faster than not using atomics. However, atomics generate results that are not strictly identical from one run to another.
135+
/// Such differences are mathematically insignificant, but when debugging, the differences are less than ideal.
136+
///
137+
/// This function sets whether atomics usage is allowed or not, unless explicitly specified in function docs, functions
138+
/// do not have an atomic specialization of the function.
139+
///
140+
/// This is `false` by default (cuBLAS will not use atomics unless explicitly set to be allowed to do so).
141+
///
142+
/// # Example
143+
///
144+
/// ```
145+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
146+
/// # let _ctx = blastoff::__doctest_setup();
147+
/// use blastoff::context::CublasContext;
148+
/// let ctx = CublasContext::new()?;
149+
/// // allows cuBLAS to use atomics to speed up functions at the cost of determinism.
150+
/// ctx.set_atomics_mode(true)?;
151+
/// # Ok(())
152+
/// # }
153+
/// ```
154+
pub fn set_atomics_mode(&self, allowed: bool) -> Result<()> {
155+
unsafe {
156+
Ok(sys::v2::cublasSetAtomicsMode(
157+
self.raw,
158+
if allowed {
159+
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED
160+
} else {
161+
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED
162+
},
163+
)
164+
.to_result()?)
165+
}
166+
}
167+
168+
/// Returns whether the context is set to be allowed to use atomics per [`CublasContext::set_atomics_mode`].
169+
/// Returns `false` unless previously explicitly set to `true`.
170+
///
171+
/// # Example
172+
///
173+
/// ```
174+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
175+
/// # let _ctx = blastoff::__doctest_setup();
176+
/// use blastoff::context::CublasContext;
177+
/// let ctx = CublasContext::new()?;
178+
/// ctx.set_atomics_mode(true)?;
179+
/// assert!(ctx.get_atomics_mode()?);
180+
/// # Ok(())
181+
/// # }
182+
/// ```
183+
pub fn get_atomics_mode(&self) -> Result<bool> {
184+
let mut mode = MaybeUninit::uninit();
185+
unsafe {
186+
sys::v2::cublasGetAtomicsMode(self.raw, mode.as_mut_ptr()).to_result()?;
187+
Ok(match mode.assume_init() {
188+
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => true,
189+
sys::v2::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => false,
190+
})
191+
}
192+
}
193+
194+
/// Sets the precision level for different routines in cuBLAS. See [`MathMode`] for more info.
195+
///
196+
/// # Example
197+
///
198+
/// ```
199+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
200+
/// # let _ctx = blastoff::__doctest_setup();
201+
/// use blastoff::context::{CublasContext, MathMode};
202+
/// let ctx = CublasContext::new()?;
203+
/// ctx.set_math_mode(MathMode::DEFAULT | MathMode::DISALLOW_REDUCED_PRECISION_REDUCTION)?;
204+
/// # Ok(())
205+
/// # }
206+
/// ```
207+
pub fn set_math_mode(&self, math_mode: MathMode) -> Result<()> {
208+
unsafe {
209+
Ok(
210+
sys::v2::cublasSetMathMode(self.raw, mem::transmute(math_mode.bits()))
211+
.to_result()?,
212+
)
213+
}
214+
}
215+
216+
/// Gets the precision level that was previously set by [`CublasContext::set_math_mode`].
217+
///
218+
/// # Example
219+
///
220+
/// ```
221+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
222+
/// # let _ctx = blastoff::__doctest_setup();
223+
/// use blastoff::context::{CublasContext, MathMode};
224+
/// let ctx = CublasContext::new()?;
225+
/// ctx.set_math_mode(MathMode::DEFAULT | MathMode::DISALLOW_REDUCED_PRECISION_REDUCTION)?;
226+
/// assert_eq!(ctx.get_math_mode()?, MathMode::DEFAULT | MathMode::DISALLOW_REDUCED_PRECISION_REDUCTION);
227+
/// # Ok(())
228+
/// # }
229+
/// ```
230+
pub fn get_math_mode(&self) -> Result<MathMode> {
231+
let mut mode = MaybeUninit::uninit();
232+
unsafe {
233+
sys::v2::cublasGetMathMode(self.raw, mode.as_mut_ptr()).to_result()?;
234+
Ok(MathMode::from_bits(mode.assume_init() as u32)
235+
.expect("Invalid MathMode from cuBLAS"))
236+
}
237+
}
238+
239+
/// Configures cuBLAS logging.
240+
///
241+
/// - `enable` will enable or disable logging completely. Off by default.
242+
/// - `log_to_stdout` will turn on/off logging to standard output. Off by default.
243+
/// - `log_to_stderr` will turn on/off logging to standard error. Off by default.
244+
/// - `log_file_name` will turn on/off logging to a file in the file system. None by default.
245+
///
246+
/// # Example
247+
///
248+
/// ```
249+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
250+
/// # let _ctx = blastoff::__doctest_setup();
251+
/// use blastoff::context::{CublasContext, MathMode};
252+
/// let ctx = CublasContext::new()?;
253+
/// // turn off logging completely
254+
/// ctx.configure_logger(false, false, false, None);
255+
/// // log to stdout and stderr
256+
/// ctx.configure_logger(true, true, true, None);
257+
/// // log to a file
258+
/// ctx.configure_logger(true, false, false, Some("./log.txt"));
259+
/// # Ok(())
260+
/// # }
261+
/// ```
262+
pub fn configure_logger(
263+
&self,
264+
enable: bool,
265+
log_to_stdout: bool,
266+
log_to_stderr: bool,
267+
log_file_name: Option<&str>,
268+
) {
269+
unsafe {
270+
let path = log_file_name.map(|p| CString::new(p).expect("nul in log_file_name"));
271+
let path_ptr = path.map_or(ptr::null(), |s| s.as_ptr());
272+
273+
sys::v2::cublasLoggerConfigure(
274+
enable as i32,
275+
log_to_stdout as i32,
276+
log_to_stderr as i32,
277+
path_ptr,
278+
)
279+
.to_result()
280+
.expect("logger configure failed");
281+
}
282+
}
283+
284+
/// Sets a function for the logger callback.
285+
///
286+
/// # Safety
287+
///
288+
/// The callback must not panic and unwind.
289+
pub unsafe fn set_logger_callback(callback: Option<unsafe extern "C" fn(*const c_char)>) {
290+
sys::v2::cublasSetLoggerCallback(callback)
291+
.to_result()
292+
.unwrap();
293+
}
294+
295+
/// Gets the logger callback that was previously set.
296+
pub fn get_logger_callback() -> Option<unsafe extern "C" fn(*const c_char)> {
297+
let mut cb = MaybeUninit::uninit();
298+
unsafe {
299+
sys::v2::cublasGetLoggerCallback(cb.as_mut_ptr())
300+
.to_result()
301+
.unwrap();
302+
cb.assume_init()
303+
}
304+
}
305+
}
306+
307+
impl Drop for CublasContext {
308+
fn drop(&mut self) {
309+
unsafe {
310+
sys::v2::cublasDestroy_v2(self.raw);
311+
}
312+
}
313+
}

0 commit comments

Comments
 (0)