Skip to content

Commit 1a59d22

Browse files
committed
Feat: add copy to blastoff
1 parent d2dffa3 commit 1a59d22

File tree

3 files changed

+74
-3
lines changed

3 files changed

+74
-3
lines changed

crates/blastoff/src/context.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ bitflags::bitflags! {
5252
/// - [Index of smallest element by absolute value <span style="float:right;">`amin`</span>](CublasContext::amin)
5353
/// - [Index of largest element by absolute value <span style="float:right;">`amax`</span>](CublasContext::amax)
5454
/// - [$\alpha \boldsymbol{x} + \boldsymbol{y}$ <span style="float:right;">`axpy`</span>](CublasContext::axpy)
55+
/// - [Copy $n$ elements from $\boldsymbol{x}$ into $\boldsymbol{y}$ <span style="float:right;">`copy`</span>](CublasContext::copy)
5556
#[derive(Debug)]
5657
pub struct CublasContext {
5758
pub(crate) raw: sys::v2::cublasHandle_t,

crates/blastoff/src/level1.rs

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ impl CublasContext {
6262
/// ```
6363
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
6464
/// # let _a = cust::quick_init()?;
65-
/// # use blastoff::context::CublasContext;
65+
/// # use blastoff::CublasContext;
6666
/// # use cust::prelude::*;
6767
/// # use cust::memory::DeviceBox;
6868
/// # use cust::util::SliceExt;
@@ -124,7 +124,7 @@ impl CublasContext {
124124
/// ```
125125
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
126126
/// # let _a = cust::quick_init()?;
127-
/// # use blastoff::context::CublasContext;
127+
/// # use blastoff::CublasContext;
128128
/// # use cust::prelude::*;
129129
/// # use cust::memory::DeviceBox;
130130
/// # use cust::util::SliceExt;
@@ -194,7 +194,7 @@ impl CublasContext {
194194
/// ```
195195
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
196196
/// # let _a = cust::quick_init()?;
197-
/// # use blastoff::context::CublasContext;
197+
/// # use blastoff::CublasContext;
198198
/// # use cust::prelude::*;
199199
/// # use cust::memory::DeviceBox;
200200
/// # use cust::util::SliceExt;
@@ -223,4 +223,72 @@ impl CublasContext {
223223
) -> Result {
224224
self.axpy_strided(stream, alpha, n, x, None, y, None)
225225
}
226+
227+
/// Same as [`CublasContext::copy`] but with an explicit stride.
228+
///
229+
/// # Panics
230+
///
231+
/// Panics if the buffers are not long enough for the stride and length requested.
232+
pub fn copy_strided<T: Level1>(
233+
&mut self,
234+
stream: &Stream,
235+
n: usize,
236+
x: &impl GpuBuffer<T>,
237+
x_stride: Option<usize>,
238+
y: &mut impl GpuBuffer<T>,
239+
y_stride: Option<usize>,
240+
) -> Result {
241+
check_stride(x, n, x_stride);
242+
check_stride(y, n, y_stride);
243+
244+
self.with_stream(stream, |ctx| unsafe {
245+
Ok(T::copy(
246+
ctx.raw,
247+
x.len() as i32,
248+
x.as_device_ptr().as_raw(),
249+
x_stride.unwrap_or(1) as i32,
250+
y.as_device_ptr().as_raw_mut(),
251+
y_stride.unwrap_or(1) as i32,
252+
)
253+
.to_result()?)
254+
})
255+
}
256+
257+
/// Copies `n` elements from `x` into `y`, overriding any previous data inside `y`.
258+
///
259+
/// # Panics
260+
///
261+
/// Panics if `x` or `y` are not large enough for the requested amount of elements.
262+
///
263+
/// # Example
264+
///
265+
/// ```
266+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
267+
/// # let _a = cust::quick_init()?;
268+
/// # use blastoff::CublasContext;
269+
/// # use cust::prelude::*;
270+
/// # use cust::memory::DeviceBox;
271+
/// # use cust::util::SliceExt;
272+
/// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
273+
/// let mut ctx = CublasContext::new()?;
274+
/// let x = [1.0f32, 2.0, 3.0, 4.0].as_dbuf()?;
275+
/// let mut y = [0.0; 4].as_dbuf()?;
276+
///
277+
/// ctx.copy(&stream, x.len(), &x, &mut y)?;
278+
///
279+
/// stream.synchronize()?;
280+
///
281+
/// assert_eq!(x.as_host_vec()?, y.as_host_vec()?);
282+
/// # Ok(())
283+
/// # }
284+
/// ```
285+
pub fn copy<T: Level1>(
286+
&mut self,
287+
stream: &Stream,
288+
n: usize,
289+
x: &impl GpuBuffer<T>,
290+
y: &mut impl GpuBuffer<T>,
291+
) -> Result {
292+
self.copy_strided(stream, n, x, None, y, None)
293+
}
226294
}

crates/blastoff/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//! you will likely need to do some math to any returned indices. For example,
88
//! [`amin`](crate::context::CublasContext::amin) returns a 1-based index.**
99
10+
#![allow(clippy::too_many_arguments)]
11+
1012
pub use cublas_sys as sys;
1113
use num_complex::{Complex32, Complex64};
1214

0 commit comments

Comments
 (0)