|
| 1 | +use crate::{ |
| 2 | + context::CublasContext, |
| 3 | + error::{Error, ToResult}, |
| 4 | + raw::GemmOps, |
| 5 | + GemmDatatype, MatrixOp, |
| 6 | +}; |
| 7 | +use cust::memory::{GpuBox, GpuBuffer}; |
| 8 | +use cust::stream::Stream; |
| 9 | + |
/// Module-local result alias: the success type defaults to `()` and the error type to [`Error`].
type Result<T = (), E = Error> = std::result::Result<T, E>;
| 11 | + |
| 12 | +#[track_caller] |
| 13 | +fn check_gemm<T: GemmDatatype + GemmOps>( |
| 14 | + m: usize, |
| 15 | + n: usize, |
| 16 | + k: usize, |
| 17 | + a: &impl GpuBuffer<T>, |
| 18 | + lda: usize, |
| 19 | + op_a: MatrixOp, |
| 20 | + b: &impl GpuBuffer<T>, |
| 21 | + ldb: usize, |
| 22 | + op_b: MatrixOp, |
| 23 | + c: &mut impl GpuBuffer<T>, |
| 24 | + ldc: usize, |
| 25 | +) { |
| 26 | + assert!(m > 0 && n > 0 && k > 0, "m, n, and k must be at least 1"); |
| 27 | + |
| 28 | + if op_a == MatrixOp::None { |
| 29 | + assert!(lda >= m, "lda must be at least m if op_a is None"); |
| 30 | + |
| 31 | + assert!( |
| 32 | + a.len() >= lda * k, |
| 33 | + "matrix A's length must be at least lda * k" |
| 34 | + ); |
| 35 | + } else { |
| 36 | + assert!(lda >= k, "lda must be at least k if op_a is None"); |
| 37 | + |
| 38 | + assert!( |
| 39 | + a.len() >= lda * m, |
| 40 | + "matrix A's length must be at least lda * m" |
| 41 | + ); |
| 42 | + } |
| 43 | + |
| 44 | + if op_b == MatrixOp::None { |
| 45 | + assert!(ldb >= k, "ldb must be at least k if op_b is None"); |
| 46 | + |
| 47 | + assert!( |
| 48 | + b.len() >= ldb * n, |
| 49 | + "matrix B's length must be at least ldb * n" |
| 50 | + ); |
| 51 | + } else { |
| 52 | + assert!(ldb >= n, "ldb must be at least n if op_b is None"); |
| 53 | + |
| 54 | + assert!( |
| 55 | + a.len() >= ldb * k, |
| 56 | + "matrix B's length must be at least ldb * k" |
| 57 | + ); |
| 58 | + } |
| 59 | + |
| 60 | + assert!(ldc >= m, "ldc must be at least m"); |
| 61 | + |
| 62 | + assert!( |
| 63 | + c.len() >= ldc * n, |
| 64 | + "matrix C's length must be at least ldc * n" |
| 65 | + ); |
| 66 | +} |
| 67 | + |
| 68 | +impl CublasContext { |
| 69 | + /// Generic Matrix Multiplication. |
| 70 | + /// |
| 71 | + /// # Panics |
| 72 | + /// |
| 73 | + /// Panics if any of the following conditions are not met: |
| 74 | + /// - `m > 0 && n > 0 && k > 0` |
| 75 | + /// - `lda >= m` if `op_a == MatrixOp::None` |
| 76 | + /// - `a.len() >= lda * k` if `op_a == MatrixOp::None` |
| 77 | + /// - `lda >= k` if `op_a == MatrixOp::Transpose` or `MatrixOp::ConjugateTranspose` |
| 78 | + /// - `a.len() >= lda * m` if `op_a == MatrixOp::Transpose` or `MatrixOp::ConjugateTranspose` |
| 79 | + /// - `ldb >= k` if `op_b == MatrixOp::None` |
| 80 | + /// - `b.len() >= ldb * n` if `op_b == MatrixOp::None` |
| 81 | + /// - `ldb >= n` if `op_b == MatrixOp::Transpose` or `MatrixOp::ConjugateTranspose` |
| 82 | + /// - `b.len() >= ldb * k` if `op_b == MatrixOp::Transpose` or `MatrixOp::ConjugateTranspose` |
| 83 | + /// - `ldc >= m` |
| 84 | + /// - `c.len() >= ldc * n` |
| 85 | + /// |
| 86 | + /// # Errors |
| 87 | + /// |
| 88 | + /// Returns an error if the kernel execution failed or the selected precision is `half` and the device does not support half precision. |
| 89 | + #[track_caller] |
| 90 | + pub fn gemm<T: GemmDatatype + GemmOps>( |
| 91 | + &mut self, |
| 92 | + stream: &Stream, |
| 93 | + m: usize, |
| 94 | + n: usize, |
| 95 | + k: usize, |
| 96 | + alpha: &impl GpuBox<T>, |
| 97 | + a: &impl GpuBuffer<T>, |
| 98 | + lda: usize, |
| 99 | + op_a: MatrixOp, |
| 100 | + beta: &impl GpuBox<T>, |
| 101 | + b: &impl GpuBuffer<T>, |
| 102 | + ldb: usize, |
| 103 | + op_b: MatrixOp, |
| 104 | + c: &mut impl GpuBuffer<T>, |
| 105 | + ldc: usize, |
| 106 | + ) -> Result { |
| 107 | + check_gemm(m, n, k, a, lda, op_a, b, ldb, op_b, c, ldc); |
| 108 | + |
| 109 | + let transa = op_a.to_raw(); |
| 110 | + let transb = op_b.to_raw(); |
| 111 | + |
| 112 | + self.with_stream(stream, |ctx| unsafe { |
| 113 | + Ok(T::gemm( |
| 114 | + ctx.raw, |
| 115 | + transa, |
| 116 | + transb, |
| 117 | + m as i32, |
| 118 | + n as i32, |
| 119 | + k as i32, |
| 120 | + alpha.as_device_ptr().as_ptr(), |
| 121 | + a.as_device_ptr().as_ptr(), |
| 122 | + lda as i32, |
| 123 | + b.as_device_ptr().as_ptr(), |
| 124 | + ldb as i32, |
| 125 | + beta.as_device_ptr().as_ptr(), |
| 126 | + c.as_device_ptr().as_mut_ptr(), |
| 127 | + ldc as i32, |
| 128 | + ) |
| 129 | + .to_result()?) |
| 130 | + }) |
| 131 | + } |
| 132 | +} |
0 commit comments