diff --git a/deps/ReactantExtra/xla_ffi.cpp b/deps/ReactantExtra/xla_ffi.cpp
index 96d6819266..abe06b85b7 100644
--- a/deps/ReactantExtra/xla_ffi.cpp
+++ b/deps/ReactantExtra/xla_ffi.cpp
@@ -132,6 +132,13 @@ ffi::Error Syrk(cublasHandle_t handle, cublasFillMode_t uplo,
   return ffi::Error::InvalidArgument("Unsupported type for syrk");
 }
 
+template <typename T>
+ffi::Error Symm(cublasHandle_t handle, cublasSideMode_t side,
+                cublasFillMode_t uplo, int m, int n, const T *alpha, const T *A,
+                int lda, const T *B, int ldb, const T *beta, T *C, int ldc) {
+  return ffi::Error::InvalidArgument("Unsupported type for symm");
+}
+
 #define SYRK_SPECIALIZATION(T, cublas_func)                                   \
   template <>                                                                 \
   ffi::Error Syrk(cublasHandle_t handle, cublasFillMode_t uplo,               \
@@ -149,6 +156,24 @@ SYRK_SPECIALIZATION(cuDoubleComplex, cublasZsyrk)
 
 #undef SYRK_SPECIALIZATION
 
+#define SYMM_SPECIALIZATION(T, cublas_func)                                   \
+  template <>                                                                 \
+  ffi::Error Symm(cublasHandle_t handle, cublasSideMode_t side,               \
+                  cublasFillMode_t uplo, int m, int n, const T *alpha,        \
+                  const T *A, int lda, const T *B, int ldb, const T *beta,    \
+                  T *C, int ldc) {                                            \
+    cublasStatus_t status = cublas_func(handle, side, uplo, m, n, alpha, A,   \
+                                        lda, B, ldb, beta, C, ldc);           \
+    return CublasStatusToError(status, #cublas_func);                         \
+  }
+
+SYMM_SPECIALIZATION(float, cublasSsymm)
+SYMM_SPECIALIZATION(double, cublasDsymm)
+SYMM_SPECIALIZATION(cuComplex, cublasCsymm)
+SYMM_SPECIALIZATION(cuDoubleComplex, cublasZsymm)
+
+#undef SYMM_SPECIALIZATION
+
 } // namespace blas
 
 // Symmetric rank-k update: syrk
@@ -310,11 +335,182 @@ XLA_FFI_DEFINE_HANDLER(
     .Ret<ffi::AnyBuffer>() // c_out
 );
 
+template <typename T>
+ffi::Error SymmImpl(CUstream stream, bool side_, bool uplo_, ffi::AnyBuffer a,
+                    ffi::AnyBuffer b, const T *alpha, const T *beta,
+                    ffi::Result<ffi::AnyBuffer> c_out) {
+  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+                       SplitBatch2D(b.dimensions()));
+  int a_size = side_ ? cols : rows; // cols if right, rows if left
+
+  FFI_RETURN_IF_ERROR(
+      CheckShape(a.dimensions(), {batch, a_size, a_size}, "a", "symm"));
+  // C should have the same shape as B.
+  FFI_RETURN_IF_ERROR(
+      CheckShape(c_out->dimensions(), {batch, rows, cols}, "c_out", "symm"));
+
+  FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
+  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
+
+  // We flip uplo here because A is passed in row-major format.
+  // Row-major A is equivalent to A^T in column-major, and since A is
+  // symmetric, this means we need to swap upper/lower triangular.
+  cublasFillMode_t uplo =
+      uplo_ ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
+  // We can swap side since (A*B)^T = B^T*A, where B^T is also the column-major
+  // interpretation of B.
+  cublasSideMode_t side = side_ ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
+
+  const T *a_data = static_cast<const T *>(a.untyped_data());
+  const T *b_data = static_cast<const T *>(b.untyped_data());
+  T *c_out_data = static_cast<T *>(c_out->untyped_data());
+
+  FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
+  // lda is the leading dimension of a, etc.
+  int lda = side == CUBLAS_SIDE_LEFT ? n : m;
+  int ldb = lda;
+  int ldc = lda;
+  for (int i = 0; i < batch; ++i) {
+    FFI_RETURN_IF_ERROR(blas::Symm(handle.get(), side, uplo, m, n, alpha,
+                                   a_data, lda, b_data, ldb, beta,
+                                   c_out_data, ldc));
+    a_data += lda * lda;
+    b_data += m * n;
+    c_out_data += m * n;
+  }
+  return ffi::Error::Success();
+}
+
+template <typename T>
+ffi::Error SymmImpl(CUstream stream, bool side_, bool uplo_,
+                    ffi::AnyBuffer a, ffi::AnyBuffer b, ffi::AnyBuffer c_in,
+                    const T *alpha, const T *beta,
+                    ffi::Result<ffi::AnyBuffer> c_out) {
+  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+                       SplitBatch2D(b.dimensions()));
+  int a_size = side_ ? cols : rows;
+
+  FFI_RETURN_IF_ERROR(
+      CheckShape(a.dimensions(), {batch, a_size, a_size}, "a", "symm"));
+  FFI_RETURN_IF_ERROR(
+      CheckShape(c_out->dimensions(), {batch, rows, cols}, "c_out", "symm"));
+
+  T *c_data = static_cast<T *>(c_in.untyped_data());
+  T *c_out_data = static_cast<T *>(c_out->untyped_data());
+
+  // Copy C into the output buffer so cuBLAS can accumulate into it in place.
+  if (c_data != c_out_data) {
+    cudaError_t err = cudaMemcpyAsync(c_out_data, c_data, c_in.size_bytes(),
+                                      cudaMemcpyDeviceToDevice, stream);
+    if (err != cudaSuccess) {
+      return ffi::Error::InvalidArgument(absl::StrFormat(
+          "cudaMemcpyAsync failed: %s", cudaGetErrorString(err)));
+    }
+  }
+  return SymmImpl(stream, side_, uplo_, a, b, alpha, beta, c_out);
+}
+
+template <typename T>
+ffi::Error
+SymmImpl(CUstream stream, bool side, bool uplo, bool use_alpha_attribute,
+         double alpha_real, double alpha_imag, bool use_beta_attribute,
+         double beta_real, double beta_imag, ffi::AnyBuffer a,
+         ffi::AnyBuffer b, ffi::AnyBuffer c_in,
+         ffi::AnyBuffer alpha_, ffi::AnyBuffer beta_,
+         ffi::Result<ffi::AnyBuffer> c_out) {
+  T host_alpha, host_beta;
+  FFI_RETURN_IF_ERROR(GetHostScalar(stream, use_alpha_attribute, alpha_real,
+                                    alpha_imag, alpha_, &host_alpha));
+  FFI_RETURN_IF_ERROR(GetHostScalar(stream, use_beta_attribute, beta_real,
+                                    beta_imag, beta_, &host_beta));
+  return SymmImpl(stream, side, uplo, a, b, c_in, &host_alpha, &host_beta,
+                  c_out);
+}
+
+template <typename T>
+ffi::Error SymmImpl(CUstream stream, bool side, bool uplo,
+                    bool use_alpha_attribute, double alpha_real,
+                    double alpha_imag, ffi::AnyBuffer a, ffi::AnyBuffer b,
+                    ffi::AnyBuffer alpha_, ffi::Result<ffi::AnyBuffer> c_out) {
+  T host_alpha, host_beta;
+  FFI_RETURN_IF_ERROR(GetHostScalar(stream, use_alpha_attribute, alpha_real,
+                                    alpha_imag, alpha_, &host_alpha));
+  // No C input, so beta is fixed to zero.
+  FFI_RETURN_IF_ERROR(GetHostScalar(0.0, 0.0, &host_beta));
+  return SymmImpl(stream, side, uplo, a, b, &host_alpha, &host_beta,
+                  c_out);
+}
+
+ffi::Error SymmDispatch(CUstream stream, bool side, bool uplo,
+                        bool use_alpha_attribute, double alpha_real,
+                        double alpha_imag, bool use_beta_attribute,
+                        double beta_real, double beta_imag, ffi::AnyBuffer a,
+                        ffi::AnyBuffer b, ffi::AnyBuffer c_in,
+                        ffi::AnyBuffer alpha_, ffi::AnyBuffer beta_,
+                        ffi::Result<ffi::AnyBuffer> c_out) {
+  auto dataType = c_in.element_type();
+  SOLVER_BLAS_DISPATCH_IMPL(SymmImpl, stream, side, uplo,
+                            use_alpha_attribute, alpha_real, alpha_imag,
+                            use_beta_attribute, beta_real, beta_imag, a, b,
+                            c_in, alpha_, beta_, c_out);
+  return ffi::Error::InvalidArgument(absl::StrFormat(
+      "Unsupported dtype %s in symm", absl::FormatStreamed(dataType)));
+}
+
+ffi::Error SymmNoCDispatch(CUstream stream, bool side, bool uplo,
+                           bool use_alpha_attribute, double alpha_real,
+                           double alpha_imag, ffi::AnyBuffer a,
+                           ffi::AnyBuffer b, ffi::AnyBuffer alpha_,
+                           ffi::Result<ffi::AnyBuffer> c_out) {
+  auto dataType = a.element_type();
+  SOLVER_BLAS_DISPATCH_IMPL(SymmImpl, stream, side, uplo,
+                            use_alpha_attribute, alpha_real, alpha_imag, a, b,
+                            alpha_, c_out);
+  return ffi::Error::InvalidArgument(absl::StrFormat(
+      "Unsupported dtype %s in symm", absl::FormatStreamed(dataType)));
+}
+
+XLA_FFI_DEFINE_HANDLER(
+    SymmFfi, SymmDispatch,
+    xla::ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<CUstream>>()
+        .Attr<bool>("side")                // side
+        .Attr<bool>("uplo")                // uplo
+        .Attr<bool>("use_alpha_attribute") // use_alpha_attribute
+        .Attr<double>("alpha_real")        // alpha_real
+        .Attr<double>("alpha_imag")        // alpha_imag
+        .Attr<bool>("use_beta_attribute")  // use_beta_attribute
+        .Attr<double>("beta_real")         // beta_real
+        .Attr<double>("beta_imag")         // beta_imag
+        .Arg<ffi::AnyBuffer>()             // a
+        .Arg<ffi::AnyBuffer>()             // b
+        .Arg<ffi::AnyBuffer>()             // c_in
+        .Arg<ffi::AnyBuffer>()             // alpha
+        .Arg<ffi::AnyBuffer>()             // beta
+        .Ret<ffi::AnyBuffer>()             // c_out
+);
+
+XLA_FFI_DEFINE_HANDLER(
+    SymmNoCFfi, SymmNoCDispatch,
+    xla::ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<CUstream>>()
+        .Attr<bool>("side")                // side
+        .Attr<bool>("uplo")                // uplo
+        .Attr<bool>("use_alpha_attribute") // use_alpha_attribute
+        .Attr<double>("alpha_real")        // alpha_real
+        .Attr<double>("alpha_imag")        // alpha_imag
+        .Arg<ffi::AnyBuffer>()             // a
+        .Arg<ffi::AnyBuffer>()             // b
+        .Arg<ffi::AnyBuffer>()             // alpha
+        .Ret<ffi::AnyBuffer>()             // c_out
+);
+
 void registerReactantXLACUDAFFI() {
   XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "reactant_cublas_syrk_ffi",
                            "CUDA", SyrkFfi);
   XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(),
                            "reactant_cublas_syrk_no_c_ffi", "CUDA", SyrkNoCFfi);
+  XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "reactant_cublas_symm_ffi",
+                           "CUDA", SymmFfi);
+  XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(),
+                           "reactant_cublas_symm_no_c_ffi", "CUDA", SymmNoCFfi);
 }
 
 #undef SOLVER_BLAS_DISPATCH_IMPL
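
Side note, not part of the patch: the uplo flip and side swap in SymmImpl come from reinterpreting row-major buffers as column-major ones. Below is a minimal standalone C++ sketch of the side-swap identity ((A*B)^T = B^T*A for symmetric A); it uses nothing from xla_ffi.cpp, and all names and the example matrices are invented for illustration only.

// symm_layout_check.cc: verify that the row-major product C = A*B (A symmetric,
// applied from the left) equals the column-major product C^T = B^T*A (A moved
// to the right), computed directly on the same row-major buffers.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int m = 2, n = 3; // C and B are m x n row-major, A is m x m symmetric.
  std::vector<double> A = {1.0, 2.0,
                           2.0, 5.0};
  std::vector<double> B = {1.0, 0.0, 2.0,
                           3.0, 1.0, 0.0};

  // Reference: row-major C = A * B.
  std::vector<double> C_ref(m * n, 0.0);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int k = 0; k < m; ++k)
        C_ref[i * n + j] += A[i * m + k] * B[k * n + j];

  // "cuBLAS view": the row-major B buffer read column-major is B^T (n x m).
  // Compute C^T = B^T * A with A on the right and store the result
  // column-major with leading dimension n; element (p, q) lands at q * n + p,
  // which is exactly the row-major address of C[q][p].
  std::vector<double> C_cm(m * n, 0.0);
  for (int p = 0; p < n; ++p)    // rows of B^T (= rows of C^T)
    for (int q = 0; q < m; ++q)  // columns of A (= columns of C^T)
      for (int k = 0; k < m; ++k)
        C_cm[q * n + p] += B[k * n + p] * A[k * m + q];

  for (int i = 0; i < m * n; ++i)
    assert(C_ref[i] == C_cm[i]);
  std::printf("row-major C = A*B matches column-major C^T = B^T*A\n");
  return 0;
}

The uplo flip follows the same reasoning: the triangle of A that the caller filled in row-major order occupies the memory cuBLAS reads as the opposite triangle in column-major order, and since A is symmetric the two triangles hold the same values, so only the fill-mode flag needs to change.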