Skip to content

Commit d7283b8

Browse files
committed
Feat: add dot, dotu, dotc, and nrm2 to blastoff, fix doctests
1 parent 736ea19 commit d7283b8

File tree

3 files changed

+378
-17
lines changed

3 files changed

+378
-17
lines changed

crates/blastoff/src/context.rs

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ bitflags::bitflags! {
5353
/// - [Index of largest element by absolute value <span style="float:right;">`amax`</span>](CublasContext::amax)
5454
/// - [$\alpha \boldsymbol{x} + \boldsymbol{y}$ <span style="float:right;">`axpy`</span>](CublasContext::axpy)
5555
/// - [Copy $n$ elements from $\boldsymbol{x}$ into $\boldsymbol{y}$ <span style="float:right;">`copy`</span>](CublasContext::copy)
56+
/// - [Dot Product <span style="float:right;">`dot`</span>](CublasContext::dot)
57+
/// - [Unconjugated Complex Dot Product <span style="float:right;">`dotu`</span>](CublasContext::dotu)
58+
/// - [Conjugated Complex Dot Product <span style="float:right;">`dotc`</span>](CublasContext::dotc)
59+
/// - [Euclidean Norm <span style="float:right;">`nrm2`</span>](CublasContext::nrm2)
5660
#[derive(Debug)]
5761
pub struct CublasContext {
5862
pub(crate) raw: sys::v2::cublasHandle_t,
@@ -65,8 +69,8 @@ impl CublasContext {
6569
///
6670
/// ```
6771
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
68-
/// # let _ctx = blastoff::__doctest_setup();
69-
/// use blastoff::context::CublasContext;
72+
/// # let _ctx = cust::quick_init()?;
73+
/// use blastoff::CublasContext;
7074
/// let ctx = CublasContext::new()?;
7175
/// # Ok(())
7276
/// # }
@@ -151,8 +155,8 @@ impl CublasContext {
151155
///
152156
/// ```
153157
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
154-
/// # let _ctx = blastoff::__doctest_setup();
155-
/// use blastoff::context::CublasContext;
158+
/// # let _ctx = cust::quick_init()?;
159+
/// use blastoff::CublasContext;
156160
/// let ctx = CublasContext::new()?;
157161
/// // allows cuBLAS to use atomics to speed up functions at the cost of determinism.
158162
/// ctx.set_atomics_mode(true)?;
@@ -180,8 +184,8 @@ impl CublasContext {
180184
///
181185
/// ```
182186
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
183-
/// # let _ctx = blastoff::__doctest_setup();
184-
/// use blastoff::context::CublasContext;
187+
/// # let _ctx = cust::quick_init()?;
188+
/// use blastoff::CublasContext;
185189
/// let ctx = CublasContext::new()?;
186190
/// ctx.set_atomics_mode(true)?;
187191
/// assert!(ctx.get_atomics_mode()?);
@@ -205,8 +209,8 @@ impl CublasContext {
205209
///
206210
/// ```
207211
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
208-
/// # let _ctx = blastoff::__doctest_setup();
209-
/// use blastoff::context::{CublasContext, MathMode};
212+
/// # let _ctx = cust::quick_init()?;
213+
/// use blastoff::{CublasContext, MathMode};
210214
/// let ctx = CublasContext::new()?;
211215
/// ctx.set_math_mode(MathMode::DEFAULT | MathMode::DISALLOW_REDUCED_PRECISION_REDUCTION)?;
212216
/// # Ok(())
@@ -227,8 +231,8 @@ impl CublasContext {
227231
///
228232
/// ```
229233
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
230-
/// # let _ctx = blastoff::__doctest_setup();
231-
/// use blastoff::context::{CublasContext, MathMode};
234+
/// # let _ctx = cust::quick_init()?;
235+
/// use blastoff::{CublasContext, MathMode};
232236
/// let ctx = CublasContext::new()?;
233237
/// ctx.set_math_mode(MathMode::DEFAULT | MathMode::DISALLOW_REDUCED_PRECISION_REDUCTION)?;
234238
/// assert_eq!(ctx.get_math_mode()?, MathMode::DEFAULT | MathMode::DISALLOW_REDUCED_PRECISION_REDUCTION);
@@ -255,8 +259,8 @@ impl CublasContext {
255259
///
256260
/// ```
257261
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
258-
/// # let _ctx = blastoff::__doctest_setup();
259-
/// use blastoff::context::{CublasContext, MathMode};
262+
/// # let _ctx = cust::quick_init()?;
263+
/// use blastoff::{CublasContext, MathMode};
260264
/// let ctx = CublasContext::new()?;
261265
/// // turn off logging completely
262266
/// ctx.configure_logger(false, false, false, None);

crates/blastoff/src/level1.rs

Lines changed: 249 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use crate::{
44
context::CublasContext,
55
error::{Error, ToResult},
6-
raw::Level1,
6+
raw::{ComplexLevel1, FloatLevel1, Level1},
77
BlasDatatype,
88
};
99
use cust::memory::{GpuBox, GpuBuffer};
@@ -45,7 +45,7 @@ impl CublasContext {
4545
self.with_stream(stream, |ctx| unsafe {
4646
Ok(T::amin(
4747
ctx.raw,
48-
x.len() as i32,
48+
n as i32,
4949
x.as_device_ptr().as_raw(),
5050
stride.unwrap_or(1) as i32,
5151
result.as_device_ptr().as_raw_mut(),
@@ -107,7 +107,7 @@ impl CublasContext {
107107
self.with_stream(stream, |ctx| unsafe {
108108
Ok(T::amax(
109109
ctx.raw,
110-
x.len() as i32,
110+
n as i32,
111111
x.as_device_ptr().as_raw(),
112112
stride.unwrap_or(1) as i32,
113113
result.as_device_ptr().as_raw_mut(),
@@ -171,7 +171,7 @@ impl CublasContext {
171171
self.with_stream(stream, |ctx| unsafe {
172172
Ok(T::axpy(
173173
ctx.raw,
174-
x.len() as i32,
174+
n as i32,
175175
alpha.as_device_ptr().as_raw(),
176176
x.as_device_ptr().as_raw(),
177177
x_stride.unwrap_or(1) as i32,
@@ -244,7 +244,7 @@ impl CublasContext {
244244
self.with_stream(stream, |ctx| unsafe {
245245
Ok(T::copy(
246246
ctx.raw,
247-
x.len() as i32,
247+
n as i32,
248248
x.as_device_ptr().as_raw(),
249249
x_stride.unwrap_or(1) as i32,
250250
y.as_device_ptr().as_raw_mut(),
@@ -291,4 +291,248 @@ impl CublasContext {
291291
) -> Result {
292292
self.copy_strided(stream, n, x, None, y, None)
293293
}
294+
295+
/// Same as [`CublasContext::dot`] but with an explicit stride.
296+
///
297+
/// # Panics
298+
///
299+
/// Panics if the buffers are not long enough for the stride and length requested.
300+
pub fn dot_strided<T: FloatLevel1>(
301+
&mut self,
302+
stream: &Stream,
303+
n: usize,
304+
x: &impl GpuBuffer<T>,
305+
x_stride: Option<usize>,
306+
y: &impl GpuBuffer<T>,
307+
y_stride: Option<usize>,
308+
result: &mut impl GpuBox<T>,
309+
) -> Result {
310+
check_stride(x, n, x_stride);
311+
check_stride(y, n, y_stride);
312+
313+
self.with_stream(stream, |ctx| unsafe {
314+
Ok(T::dot(
315+
ctx.raw,
316+
n as i32,
317+
x.as_device_ptr().as_raw(),
318+
x_stride.unwrap_or(1) as i32,
319+
y.as_device_ptr().as_raw(),
320+
y_stride.unwrap_or(1) as i32,
321+
result.as_device_ptr().as_raw_mut(),
322+
)
323+
.to_result()?)
324+
})
325+
}
326+
327+
/// Computes the dot product of two vectors:
328+
///
329+
/// $$
330+
/// \sum^n_{i=1} \boldsymbol{x}_i * \boldsymbol{y}_i
331+
/// $$
332+
///
333+
/// # Panics
334+
///
335+
/// Panics if the buffers are not long enough for the length requested.
336+
///
337+
/// # Example
338+
///
339+
/// ```
340+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
341+
/// # let _a = cust::quick_init()?;
342+
/// # use blastoff::CublasContext;
343+
/// # use cust::prelude::*;
344+
/// # use cust::memory::DeviceBox;
345+
/// # use cust::util::SliceExt;
346+
/// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
347+
/// let mut ctx = CublasContext::new()?;
348+
/// let x = [1.0f32, 2.0, 3.0, 4.0].as_dbuf()?;
349+
/// let y = [1.0f32, 2.0, 3.0, 4.0].as_dbuf()?;
350+
/// let mut result = DeviceBox::new(&0.0)?;
351+
///
352+
/// ctx.dot(&stream, x.len(), &x, &y, &mut result)?;
353+
///
354+
/// stream.synchronize()?;
355+
///
356+
/// assert_eq!(result.as_host_value()?, 30.0);
357+
/// # Ok(())
358+
/// # }
359+
/// ```
360+
pub fn dot<T: FloatLevel1>(
361+
&mut self,
362+
stream: &Stream,
363+
n: usize,
364+
x: &impl GpuBuffer<T>,
365+
y: &impl GpuBuffer<T>,
366+
result: &mut impl GpuBox<T>,
367+
) -> Result {
368+
self.dot_strided(stream, n, x, None, y, None, result)
369+
}
370+
371+
/// Same as [`CublasContext::dotu`] but with an explicit stride.
372+
///
373+
/// # Panics
374+
///
375+
/// Panics if the buffers are not long enough for the stride and length requested.
376+
pub fn dotu_strided<T: ComplexLevel1>(
377+
&mut self,
378+
stream: &Stream,
379+
n: usize,
380+
x: &impl GpuBuffer<T>,
381+
x_stride: Option<usize>,
382+
y: &impl GpuBuffer<T>,
383+
y_stride: Option<usize>,
384+
result: &mut impl GpuBox<T>,
385+
) -> Result {
386+
check_stride(x, n, x_stride);
387+
check_stride(y, n, y_stride);
388+
389+
self.with_stream(stream, |ctx| unsafe {
390+
Ok(T::dotu(
391+
ctx.raw,
392+
n as i32,
393+
x.as_device_ptr().as_raw(),
394+
x_stride.unwrap_or(1) as i32,
395+
y.as_device_ptr().as_raw(),
396+
y_stride.unwrap_or(1) as i32,
397+
result.as_device_ptr().as_raw_mut(),
398+
)
399+
.to_result()?)
400+
})
401+
}
402+
403+
/// Computes the unconjugated dot product of two vectors of complex numbers.
404+
///
405+
/// # Panics
406+
///
407+
/// Panics if the buffers are not long enough for the length requested.
408+
pub fn dotu<T: ComplexLevel1>(
409+
&mut self,
410+
stream: &Stream,
411+
n: usize,
412+
x: &impl GpuBuffer<T>,
413+
y: &impl GpuBuffer<T>,
414+
result: &mut impl GpuBox<T>,
415+
) -> Result {
416+
self.dotu_strided(stream, n, x, None, y, None, result)
417+
}
418+
419+
/// Same as [`CublasContext::dotc`] but with an explicit stride.
420+
///
421+
/// # Panics
422+
///
423+
/// Panics if the buffers are not long enough for the stride and length requested.
424+
pub fn dotc_strided<T: ComplexLevel1>(
425+
&mut self,
426+
stream: &Stream,
427+
n: usize,
428+
x: &impl GpuBuffer<T>,
429+
x_stride: Option<usize>,
430+
y: &impl GpuBuffer<T>,
431+
y_stride: Option<usize>,
432+
result: &mut impl GpuBox<T>,
433+
) -> Result {
434+
check_stride(x, n, x_stride);
435+
check_stride(y, n, y_stride);
436+
437+
self.with_stream(stream, |ctx| unsafe {
438+
Ok(T::dotc(
439+
ctx.raw,
440+
n as i32,
441+
x.as_device_ptr().as_raw(),
442+
x_stride.unwrap_or(1) as i32,
443+
y.as_device_ptr().as_raw(),
444+
y_stride.unwrap_or(1) as i32,
445+
result.as_device_ptr().as_raw_mut(),
446+
)
447+
.to_result()?)
448+
})
449+
}
450+
451+
/// Computes the conjugated dot product of two vectors of complex numbers.
452+
///
453+
/// # Panics
454+
///
455+
/// Panics if the buffers are not long enough for the length requested.
456+
pub fn dotc<T: ComplexLevel1>(
457+
&mut self,
458+
stream: &Stream,
459+
n: usize,
460+
x: &impl GpuBuffer<T>,
461+
y: &impl GpuBuffer<T>,
462+
result: &mut impl GpuBox<T>,
463+
) -> Result {
464+
self.dotc_strided(stream, n, x, None, y, None, result)
465+
}
466+
467+
/// Same as [`CublasContext::nrm2`] but with an explicit stride.
468+
///
469+
/// # Panics
470+
///
471+
/// Panics if the buffers are not long enough for the stride and length requested.
472+
pub fn nrm2_strided<T: Level1>(
473+
&mut self,
474+
stream: &Stream,
475+
n: usize,
476+
x: &impl GpuBuffer<T>,
477+
x_stride: Option<usize>,
478+
result: &mut impl GpuBox<T::FloatTy>,
479+
) -> Result {
480+
check_stride(x, n, x_stride);
481+
482+
self.with_stream(stream, |ctx| unsafe {
483+
Ok(T::nrm2(
484+
ctx.raw,
485+
n as i32,
486+
x.as_device_ptr().as_raw(),
487+
x_stride.unwrap_or(1) as i32,
488+
result.as_device_ptr().as_raw_mut(),
489+
)
490+
.to_result()?)
491+
})
492+
}
493+
494+
/// Computes the Euclidean norm of a vector, in other words, the square root of
495+
/// the sum of the squares of each element in `x`:
496+
///
497+
/// $$
498+
/// \sqrt{\sum_{i=1}^n |\boldsymbol{x}_i|^2}
499+
/// $$
500+
///
501+
/// # Panics
502+
///
503+
/// Panics if `x` is not large enough for the requested length `n`.
504+
///
505+
/// # Example
506+
///
507+
/// ```
508+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
509+
/// # let _a = cust::quick_init()?;
510+
/// # use blastoff::CublasContext;
511+
/// # use cust::prelude::*;
512+
/// # use cust::memory::DeviceBox;
513+
/// # use cust::util::SliceExt;
514+
/// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
515+
/// let mut ctx = CublasContext::new()?;
516+
/// let x = [2.0f32; 4].as_dbuf()?;
517+
/// let mut result = DeviceBox::new(&0.0f32)?;
518+
///
519+
/// ctx.nrm2(&stream, x.len(), &x, &mut result)?;
520+
///
521+
/// stream.synchronize()?;
522+
///
523+
/// let result = result.as_host_value()?;
524+
/// // the exact answer is 4.0; allow a small tolerance for floating-point rounding
525+
/// assert!(result >= 3.9 && result <= 4.0);
526+
/// # Ok(())
527+
/// # }
528+
/// ```
529+
pub fn nrm2<T: Level1>(
530+
&mut self,
531+
stream: &Stream,
532+
n: usize,
533+
x: &impl GpuBuffer<T>,
534+
result: &mut impl GpuBox<T::FloatTy>,
535+
) -> Result {
536+
self.nrm2_strided(stream, n, x, None, result)
537+
}
294538
}

0 commit comments

Comments
 (0)