|
3 | 3 | use crate::{
|
4 | 4 | context::CublasContext,
|
5 | 5 | error::{Error, ToResult},
|
6 |
| - raw::Level1, |
| 6 | + raw::{ComplexLevel1, FloatLevel1, Level1}, |
7 | 7 | BlasDatatype,
|
8 | 8 | };
|
9 | 9 | use cust::memory::{GpuBox, GpuBuffer};
|
@@ -45,7 +45,7 @@ impl CublasContext {
|
45 | 45 | self.with_stream(stream, |ctx| unsafe {
|
46 | 46 | Ok(T::amin(
|
47 | 47 | ctx.raw,
|
48 |
| - x.len() as i32, |
| 48 | + n as i32, |
49 | 49 | x.as_device_ptr().as_raw(),
|
50 | 50 | stride.unwrap_or(1) as i32,
|
51 | 51 | result.as_device_ptr().as_raw_mut(),
|
@@ -107,7 +107,7 @@ impl CublasContext {
|
107 | 107 | self.with_stream(stream, |ctx| unsafe {
|
108 | 108 | Ok(T::amax(
|
109 | 109 | ctx.raw,
|
110 |
| - x.len() as i32, |
| 110 | + n as i32, |
111 | 111 | x.as_device_ptr().as_raw(),
|
112 | 112 | stride.unwrap_or(1) as i32,
|
113 | 113 | result.as_device_ptr().as_raw_mut(),
|
@@ -171,7 +171,7 @@ impl CublasContext {
|
171 | 171 | self.with_stream(stream, |ctx| unsafe {
|
172 | 172 | Ok(T::axpy(
|
173 | 173 | ctx.raw,
|
174 |
| - x.len() as i32, |
| 174 | + n as i32, |
175 | 175 | alpha.as_device_ptr().as_raw(),
|
176 | 176 | x.as_device_ptr().as_raw(),
|
177 | 177 | x_stride.unwrap_or(1) as i32,
|
@@ -244,7 +244,7 @@ impl CublasContext {
|
244 | 244 | self.with_stream(stream, |ctx| unsafe {
|
245 | 245 | Ok(T::copy(
|
246 | 246 | ctx.raw,
|
247 |
| - x.len() as i32, |
| 247 | + n as i32, |
248 | 248 | x.as_device_ptr().as_raw(),
|
249 | 249 | x_stride.unwrap_or(1) as i32,
|
250 | 250 | y.as_device_ptr().as_raw_mut(),
|
@@ -291,4 +291,248 @@ impl CublasContext {
|
291 | 291 | ) -> Result {
|
292 | 292 | self.copy_strided(stream, n, x, None, y, None)
|
293 | 293 | }
|
| 294 | + |
| 295 | + /// Same as [`CublasContext::dot`] but with an explicit stride. |
| 296 | + /// |
| 297 | + /// # Panics |
| 298 | + /// |
| 299 | + /// Panics if the buffers are not long enough for the stride and length requested. |
| 300 | + pub fn dot_strided<T: FloatLevel1>( |
| 301 | + &mut self, |
| 302 | + stream: &Stream, |
| 303 | + n: usize, |
| 304 | + x: &impl GpuBuffer<T>, |
| 305 | + x_stride: Option<usize>, |
| 306 | + y: &impl GpuBuffer<T>, |
| 307 | + y_stride: Option<usize>, |
| 308 | + result: &mut impl GpuBox<T>, |
| 309 | + ) -> Result { |
| 310 | + check_stride(x, n, x_stride); |
| 311 | + check_stride(y, n, y_stride); |
| 312 | + |
| 313 | + self.with_stream(stream, |ctx| unsafe { |
| 314 | + Ok(T::dot( |
| 315 | + ctx.raw, |
| 316 | + n as i32, |
| 317 | + x.as_device_ptr().as_raw(), |
| 318 | + x_stride.unwrap_or(1) as i32, |
| 319 | + y.as_device_ptr().as_raw(), |
| 320 | + y_stride.unwrap_or(1) as i32, |
| 321 | + result.as_device_ptr().as_raw_mut(), |
| 322 | + ) |
| 323 | + .to_result()?) |
| 324 | + }) |
| 325 | + } |
| 326 | + |
| 327 | + /// Computes the dot product of two vectors: |
| 328 | + /// |
| 329 | + /// $$ |
| 330 | + /// \sum^n_{i=1} \boldsymbol{x}_i * \boldsymbol{y}_i |
| 331 | + /// $$ |
| 332 | + /// |
| 333 | + /// # Panics |
| 334 | + /// |
| 335 | + /// Panics if the buffers are not long enough for the length requested. |
| 336 | + /// |
| 337 | + /// # Example |
| 338 | + /// |
| 339 | + /// ``` |
| 340 | + /// # fn main() -> Result<(), Box<dyn std::error::Error>> { |
| 341 | + /// # let _a = cust::quick_init()?; |
| 342 | + /// # use blastoff::CublasContext; |
| 343 | + /// # use cust::prelude::*; |
| 344 | + /// # use cust::memory::DeviceBox; |
| 345 | + /// # use cust::util::SliceExt; |
| 346 | + /// # let stream = Stream::new(StreamFlags::DEFAULT, None)?; |
| 347 | + /// let mut ctx = CublasContext::new()?; |
| 348 | + /// let x = [1.0f32, 2.0, 3.0, 4.0].as_dbuf()?; |
| 349 | + /// let y = [1.0f32, 2.0, 3.0, 4.0].as_dbuf()?; |
| 350 | + /// let mut result = DeviceBox::new(&0.0)?; |
| 351 | + /// |
| 352 | + /// ctx.dot(&stream, x.len(), &x, &y, &mut result)?; |
| 353 | + /// |
| 354 | + /// stream.synchronize()?; |
| 355 | + /// |
| 356 | + /// assert_eq!(result.as_host_value()?, 30.0); |
| 357 | + /// # Ok(()) |
| 358 | + /// # } |
| 359 | + /// ``` |
| 360 | + pub fn dot<T: FloatLevel1>( |
| 361 | + &mut self, |
| 362 | + stream: &Stream, |
| 363 | + n: usize, |
| 364 | + x: &impl GpuBuffer<T>, |
| 365 | + y: &impl GpuBuffer<T>, |
| 366 | + result: &mut impl GpuBox<T>, |
| 367 | + ) -> Result { |
| 368 | + self.dot_strided(stream, n, x, None, y, None, result) |
| 369 | + } |
| 370 | + |
| 371 | + /// Same as [`CublasContext::dotu`] but with an explicit stride. |
| 372 | + /// |
| 373 | + /// # Panics |
| 374 | + /// |
| 375 | + /// Panics if the buffers are not long enough for the stride and length requested. |
| 376 | + pub fn dotu_strided<T: ComplexLevel1>( |
| 377 | + &mut self, |
| 378 | + stream: &Stream, |
| 379 | + n: usize, |
| 380 | + x: &impl GpuBuffer<T>, |
| 381 | + x_stride: Option<usize>, |
| 382 | + y: &impl GpuBuffer<T>, |
| 383 | + y_stride: Option<usize>, |
| 384 | + result: &mut impl GpuBox<T>, |
| 385 | + ) -> Result { |
| 386 | + check_stride(x, n, x_stride); |
| 387 | + check_stride(y, n, y_stride); |
| 388 | + |
| 389 | + self.with_stream(stream, |ctx| unsafe { |
| 390 | + Ok(T::dotu( |
| 391 | + ctx.raw, |
| 392 | + n as i32, |
| 393 | + x.as_device_ptr().as_raw(), |
| 394 | + x_stride.unwrap_or(1) as i32, |
| 395 | + y.as_device_ptr().as_raw(), |
| 396 | + y_stride.unwrap_or(1) as i32, |
| 397 | + result.as_device_ptr().as_raw_mut(), |
| 398 | + ) |
| 399 | + .to_result()?) |
| 400 | + }) |
| 401 | + } |
| 402 | + |
| 403 | + /// Computes the unconjugated dot product of two vectors of complex numbers. |
| 404 | + /// |
| 405 | + /// # Panics |
| 406 | + /// |
| 407 | + /// Panics if the buffers are not long enough for the length requested. |
| 408 | + pub fn dotu<T: ComplexLevel1>( |
| 409 | + &mut self, |
| 410 | + stream: &Stream, |
| 411 | + n: usize, |
| 412 | + x: &impl GpuBuffer<T>, |
| 413 | + y: &impl GpuBuffer<T>, |
| 414 | + result: &mut impl GpuBox<T>, |
| 415 | + ) -> Result { |
| 416 | + self.dotu_strided(stream, n, x, None, y, None, result) |
| 417 | + } |
| 418 | + |
| 419 | + /// Same as [`CublasContext::dotc`] but with an explicit stride. |
| 420 | + /// |
| 421 | + /// # Panics |
| 422 | + /// |
| 423 | + /// Panics if the buffers are not long enough for the stride and length requested. |
| 424 | + pub fn dotc_strided<T: ComplexLevel1>( |
| 425 | + &mut self, |
| 426 | + stream: &Stream, |
| 427 | + n: usize, |
| 428 | + x: &impl GpuBuffer<T>, |
| 429 | + x_stride: Option<usize>, |
| 430 | + y: &impl GpuBuffer<T>, |
| 431 | + y_stride: Option<usize>, |
| 432 | + result: &mut impl GpuBox<T>, |
| 433 | + ) -> Result { |
| 434 | + check_stride(x, n, x_stride); |
| 435 | + check_stride(y, n, y_stride); |
| 436 | + |
| 437 | + self.with_stream(stream, |ctx| unsafe { |
| 438 | + Ok(T::dotc( |
| 439 | + ctx.raw, |
| 440 | + n as i32, |
| 441 | + x.as_device_ptr().as_raw(), |
| 442 | + x_stride.unwrap_or(1) as i32, |
| 443 | + y.as_device_ptr().as_raw(), |
| 444 | + y_stride.unwrap_or(1) as i32, |
| 445 | + result.as_device_ptr().as_raw_mut(), |
| 446 | + ) |
| 447 | + .to_result()?) |
| 448 | + }) |
| 449 | + } |
| 450 | + |
| 451 | + /// Computes the conjugated dot product of two vectors of complex numbers. |
| 452 | + /// |
| 453 | + /// # Panics |
| 454 | + /// |
| 455 | + /// Panics if the buffers are not long enough for the length requested. |
| 456 | + pub fn dotc<T: ComplexLevel1>( |
| 457 | + &mut self, |
| 458 | + stream: &Stream, |
| 459 | + n: usize, |
| 460 | + x: &impl GpuBuffer<T>, |
| 461 | + y: &impl GpuBuffer<T>, |
| 462 | + result: &mut impl GpuBox<T>, |
| 463 | + ) -> Result { |
| 464 | + self.dotc_strided(stream, n, x, None, y, None, result) |
| 465 | + } |
| 466 | + |
| 467 | + /// Same as [`CublasContext::nrm2`] but with an explicit stride. |
| 468 | + /// |
| 469 | + /// # Panics |
| 470 | + /// |
| 471 | + /// Panics if the buffers are not long enough for the stride and length requested. |
| 472 | + pub fn nrm2_strided<T: Level1>( |
| 473 | + &mut self, |
| 474 | + stream: &Stream, |
| 475 | + n: usize, |
| 476 | + x: &impl GpuBuffer<T>, |
| 477 | + x_stride: Option<usize>, |
| 478 | + result: &mut impl GpuBox<T::FloatTy>, |
| 479 | + ) -> Result { |
| 480 | + check_stride(x, n, x_stride); |
| 481 | + |
| 482 | + self.with_stream(stream, |ctx| unsafe { |
| 483 | + Ok(T::nrm2( |
| 484 | + ctx.raw, |
| 485 | + n as i32, |
| 486 | + x.as_device_ptr().as_raw(), |
| 487 | + x_stride.unwrap_or(1) as i32, |
| 488 | + result.as_device_ptr().as_raw_mut(), |
| 489 | + ) |
| 490 | + .to_result()?) |
| 491 | + }) |
| 492 | + } |
| 493 | + |
| 494 | + /// Computes the euclidian norm of a vector, in other words, the square root of |
| 495 | + /// the sum of the squares of each element in `x`: |
| 496 | + /// |
| 497 | + /// $$ |
| 498 | + /// \sqrt{\sum_{i=1}^n (\boldsymbol{x}_i^2)} |
| 499 | + /// $$ |
| 500 | + /// |
| 501 | + /// # Panics |
| 502 | + /// |
| 503 | + /// Panics if `x` is not large enough for the requested length `n`. |
| 504 | + /// |
| 505 | + /// # Example |
| 506 | + /// |
| 507 | + /// ``` |
| 508 | + /// # fn main() -> Result<(), Box<dyn std::error::Error>> { |
| 509 | + /// # let _a = cust::quick_init()?; |
| 510 | + /// # use blastoff::CublasContext; |
| 511 | + /// # use cust::prelude::*; |
| 512 | + /// # use cust::memory::DeviceBox; |
| 513 | + /// # use cust::util::SliceExt; |
| 514 | + /// # let stream = Stream::new(StreamFlags::DEFAULT, None)?; |
| 515 | + /// let mut ctx = CublasContext::new()?; |
| 516 | + /// let x = [2.0f32; 4].as_dbuf()?; |
| 517 | + /// let mut result = DeviceBox::new(&0.0f32)?; |
| 518 | + /// |
| 519 | + /// ctx.nrm2(&stream, x.len(), &x, &mut result)?; |
| 520 | + /// |
| 521 | + /// stream.synchronize()?; |
| 522 | + /// |
| 523 | + /// let result = result.as_host_value()?; |
| 524 | + /// // float weirdness |
| 525 | + /// assert!(result >= 3.9 && result <= 4.0); |
| 526 | + /// # Ok(()) |
| 527 | + /// # } |
| 528 | + /// ``` |
| 529 | + pub fn nrm2<T: Level1>( |
| 530 | + &mut self, |
| 531 | + stream: &Stream, |
| 532 | + n: usize, |
| 533 | + x: &impl GpuBuffer<T>, |
| 534 | + result: &mut impl GpuBox<T::FloatTy>, |
| 535 | + ) -> Result { |
| 536 | + self.nrm2_strided(stream, n, x, None, result) |
| 537 | + } |
294 | 538 | }
|
0 commit comments