@@ -62,7 +62,7 @@ impl CublasContext {
62
62
/// ```
63
63
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
64
64
/// # let _a = cust::quick_init()?;
65
- /// # use blastoff::context:: CublasContext;
65
+ /// # use blastoff::CublasContext;
66
66
/// # use cust::prelude::*;
67
67
/// # use cust::memory::DeviceBox;
68
68
/// # use cust::util::SliceExt;
@@ -124,7 +124,7 @@ impl CublasContext {
124
124
/// ```
125
125
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
126
126
/// # let _a = cust::quick_init()?;
127
- /// # use blastoff::context:: CublasContext;
127
+ /// # use blastoff::CublasContext;
128
128
/// # use cust::prelude::*;
129
129
/// # use cust::memory::DeviceBox;
130
130
/// # use cust::util::SliceExt;
@@ -194,7 +194,7 @@ impl CublasContext {
194
194
/// ```
195
195
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
196
196
/// # let _a = cust::quick_init()?;
197
- /// # use blastoff::context:: CublasContext;
197
+ /// # use blastoff::CublasContext;
198
198
/// # use cust::prelude::*;
199
199
/// # use cust::memory::DeviceBox;
200
200
/// # use cust::util::SliceExt;
@@ -223,4 +223,72 @@ impl CublasContext {
223
223
) -> Result {
224
224
self . axpy_strided ( stream, alpha, n, x, None , y, None )
225
225
}
226
+
227
+ /// Same as [`CublasContext::copy`] but with an explicit stride.
228
+ ///
229
+ /// # Panics
230
+ ///
231
+ /// Panics if the buffers are not long enough for the stride and length requested.
232
+ pub fn copy_strided < T : Level1 > (
233
+ & mut self ,
234
+ stream : & Stream ,
235
+ n : usize ,
236
+ x : & impl GpuBuffer < T > ,
237
+ x_stride : Option < usize > ,
238
+ y : & mut impl GpuBuffer < T > ,
239
+ y_stride : Option < usize > ,
240
+ ) -> Result {
241
+ check_stride ( x, n, x_stride) ;
242
+ check_stride ( y, n, y_stride) ;
243
+
244
+ self . with_stream ( stream, |ctx| unsafe {
245
+ Ok ( T :: copy (
246
+ ctx. raw ,
247
+ x. len ( ) as i32 ,
248
+ x. as_device_ptr ( ) . as_raw ( ) ,
249
+ x_stride. unwrap_or ( 1 ) as i32 ,
250
+ y. as_device_ptr ( ) . as_raw_mut ( ) ,
251
+ y_stride. unwrap_or ( 1 ) as i32 ,
252
+ )
253
+ . to_result ( ) ?)
254
+ } )
255
+ }
256
+
257
+ /// Copies `n` elements from `x` into `y`, overriding any previous data inside `y`.
258
+ ///
259
+ /// # Panics
260
+ ///
261
+ /// Panics if `x` or `y` are not large enough for the requested amount of elements.
262
+ ///
263
+ /// # Example
264
+ ///
265
+ /// ```
266
+ /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
267
+ /// # let _a = cust::quick_init()?;
268
+ /// # use blastoff::CublasContext;
269
+ /// # use cust::prelude::*;
270
+ /// # use cust::memory::DeviceBox;
271
+ /// # use cust::util::SliceExt;
272
+ /// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
273
+ /// let mut ctx = CublasContext::new()?;
274
+ /// let x = [1.0f32, 2.0, 3.0, 4.0].as_dbuf()?;
275
+ /// let mut y = [0.0; 4].as_dbuf()?;
276
+ ///
277
+ /// ctx.copy(&stream, x.len(), &x, &mut y)?;
278
+ ///
279
+ /// stream.synchronize()?;
280
+ ///
281
+ /// assert_eq!(x.as_host_vec()?, y.as_host_vec()?);
282
+ /// # Ok(())
283
+ /// # }
284
+ /// ```
285
+ pub fn copy < T : Level1 > (
286
+ & mut self ,
287
+ stream : & Stream ,
288
+ n : usize ,
289
+ x : & impl GpuBuffer < T > ,
290
+ y : & mut impl GpuBuffer < T > ,
291
+ ) -> Result {
292
+ self . copy_strided ( stream, n, x, None , y, None )
293
+ }
226
294
}
0 commit comments