Skip to content

Commit fa27a45

Browse files
committed
Feat: add rot to blastoff
1 parent d7283b8 commit fa27a45

File tree

3 files changed

+112
-5
lines changed

3 files changed

+112
-5
lines changed

crates/blastoff/src/context.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ bitflags::bitflags! {
5757
/// - [Unconjugated Complex Dot Product <span style="float:right;">`dotu`</span>](CublasContext::dotu)
5858
/// - [Conjugated Complex Dot Product <span style="float:right;">`dotc`</span>](CublasContext::dotc)
5959
/// - [Euclidean Norm <span style="float:right;">`nrm2`</span>](CublasContext::nrm2)
60+
/// - [Rotate points in the xy-plane using a Givens rotation matrix <span style="float:right;">`rot`</span>](CublasContext::rot)
6061
#[derive(Debug)]
6162
pub struct CublasContext {
6263
pub(crate) raw: sys::v2::cublasHandle_t,

crates/blastoff/src/level1.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,4 +535,110 @@ impl CublasContext {
535535
) -> Result {
536536
self.nrm2_strided(stream, n, x, None, result)
537537
}
538+
539+
/// Same as [`CublasContext::rot`] but with an explicit stride.
540+
///
541+
/// # Panics
542+
///
543+
/// Panics if the buffers are not long enough for the stride and length requested.
544+
pub fn rot_strided<T: Level1>(
545+
&mut self,
546+
stream: &Stream,
547+
n: usize,
548+
x: &mut impl GpuBuffer<T>,
549+
x_stride: Option<usize>,
550+
y: &mut impl GpuBuffer<T>,
551+
y_stride: Option<usize>,
552+
c: &impl GpuBox<T::FloatTy>,
553+
s: &impl GpuBox<T::FloatTy>,
554+
) -> Result {
555+
check_stride(x, n, x_stride);
556+
check_stride(y, n, y_stride);
557+
558+
self.with_stream(stream, |ctx| unsafe {
559+
Ok(T::rot(
560+
ctx.raw,
561+
n as i32,
562+
x.as_device_ptr().as_raw_mut(),
563+
x_stride.unwrap_or(1) as i32,
564+
y.as_device_ptr().as_raw_mut(),
565+
y_stride.unwrap_or(1) as i32,
566+
c.as_device_ptr().as_raw(),
567+
s.as_device_ptr().as_raw(),
568+
)
569+
.to_result()?)
570+
})
571+
}
572+
573+
/// Rotates points in the xy-plane using a Givens rotation matrix.
574+
///
575+
/// Rotation matrix:
576+
///
577+
/// <p>
578+
/// $$
579+
/// \begin{pmatrix}
580+
/// c & s \\
581+
/// -s & c
582+
/// \end{pmatrix}
583+
/// $$
584+
/// </p>
585+
///
586+
/// Therefore:
587+
///
588+
/// $$
589+
/// \boldsymbol{x}_i = \boldsymbol{x}_ic + \boldsymbol{y}_is
590+
/// $$
591+
///
592+
/// And:
593+
///
594+
/// $$
595+
/// \boldsymbol{y}_i = -\boldsymbol{x}_is + \boldsymbol{y}_ic
596+
/// $$
597+
///
598+
/// Where $c$ and $s$ are usually
599+
///
600+
/// <p>
601+
/// $$
602+
/// c = cos(\theta) \\
603+
/// s = sin(\theta)
604+
/// $$
605+
/// </p>
606+
///
607+
/// # Example
608+
///
609+
/// ```
610+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
611+
/// # let _a = cust::quick_init()?;
612+
/// # use blastoff::CublasContext;
613+
/// # use cust::prelude::*;
614+
/// # use cust::memory::DeviceBox;
615+
/// # use cust::util::SliceExt;
616+
/// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
617+
/// let mut ctx = CublasContext::new()?;
618+
/// let mut x = [1.0f32].as_dbuf()?;
619+
/// let mut y = [0.0].as_dbuf()?;
620+
/// let c = DeviceBox::new(&1.0)?;
621+
/// let s = DeviceBox::new(&0.0)?;
622+
///
623+
/// ctx.rot(&stream, x.len(), &mut x, &mut y, &c, &s)?;
624+
///
625+
/// stream.synchronize()?;
626+
///
627+
/// // identity matrix
628+
/// assert_eq!(&x.as_host_vec()?, &[1.0]);
629+
/// assert_eq!(&y.as_host_vec()?, &[0.0]);
630+
/// # Ok(())
631+
/// # }
632+
/// ```
633+
pub fn rot<T: Level1>(
634+
&mut self,
635+
stream: &Stream,
636+
n: usize,
637+
x: &mut impl GpuBuffer<T>,
638+
y: &mut impl GpuBuffer<T>,
639+
c: &impl GpuBox<T::FloatTy>,
640+
s: &impl GpuBox<T::FloatTy>,
641+
) -> Result {
642+
self.rot_strided(stream, n, x, None, y, None, c, s)
643+
}
538644
}

crates/blastoff/src/raw/level1.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ pub trait Level1: BlasDatatype {
5050
y: *mut Self,
5151
incy: c_int,
5252
c: *const Self::FloatTy,
53-
s: *const Self,
53+
s: *const Self::FloatTy,
5454
) -> cublasStatus_t;
5555
unsafe fn rotg(
5656
handle: cublasHandle_t,
@@ -315,9 +315,9 @@ impl Level1 for Complex32 {
315315
y: *mut Self,
316316
incy: c_int,
317317
c: *const Self::FloatTy,
318-
s: *const Self,
318+
s: *const Self::FloatTy,
319319
) -> cublasStatus_t {
320-
cublasCrot_v2(handle, n, x.cast(), incx, y.cast(), incy, c, s.cast())
320+
cublasCsrot_v2(handle, n, x.cast(), incx, y.cast(), incy, c, s)
321321
}
322322
unsafe fn rotg(
323323
handle: cublasHandle_t,
@@ -406,9 +406,9 @@ impl Level1 for Complex64 {
406406
y: *mut Self,
407407
incy: c_int,
408408
c: *const Self::FloatTy,
409-
s: *const Self,
409+
s: *const Self::FloatTy,
410410
) -> cublasStatus_t {
411-
cublasZrot_v2(handle, n, x.cast(), incx, y.cast(), incy, c, s.cast())
411+
cublasZdrot_v2(handle, n, x.cast(), incx, y.cast(), incy, c, s)
412412
}
413413
unsafe fn rotg(
414414
handle: cublasHandle_t,

0 commit comments

Comments
 (0)