Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions deps/ReactantExtra/xla_ffi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ ffi::Error Syrk(cublasHandle_t handle, cublasFillMode_t uplo,
return ffi::Error::InvalidArgument("Unsupported type for syrk");
}

template <typename T>
ffi::Error Symm(cublasHandle_t handle, cublasSideMode_t side,
cublasFillMode_t uplo, int m, int n, const T *alpha, const T *A,
int lda, const T *B, int ldb, const T *beta, T *C, int ldc) {
return ffi::Error::InvalidArgument("Unsupported type for symm");
}

#define SYRK_SPECIALIZATION(T, cublas_func) \
template <> \
ffi::Error Syrk<T>(cublasHandle_t handle, cublasFillMode_t uplo, \
Expand All @@ -149,6 +156,24 @@ SYRK_SPECIALIZATION(cuDoubleComplex, cublasZsyrk)

#undef SYRK_SPECIALIZATION

#define SYMM_SPECIALIZATION(T, cublas_func) \
template <> \
ffi::Error Symm<T>(cublasHandle_t handle, cublasSideMode_t side, \
cublasFillMode_t uplo, int m, int n, const T *alpha, \
const T *A, int lda, const T *B, int ldb, const T *beta, \
T *C, int ldc) { \
cublasStatus_t status = cublas_func(handle, side, uplo, m, n, alpha, A, \
lda, B, ldb, beta, C, ldc); \
return CublasStatusToError(status, #cublas_func); \
}

SYMM_SPECIALIZATION(float, cublasSsymm)
SYMM_SPECIALIZATION(double, cublasDsymm)
SYMM_SPECIALIZATION(cuComplex, cublasCsymm)
SYMM_SPECIALIZATION(cuDoubleComplex, cublasZsymm)

#undef SYMM_SPECIALIZATION

} // namespace blas

// Symmetric rank-k update: syrk
Expand Down Expand Up @@ -310,11 +335,182 @@ XLA_FFI_DEFINE_HANDLER(
.Ret<ffi::AnyBuffer>() // c_out
);

template <typename T>
ffi::Error SymmImpl(CUstream stream, bool side_, bool uplo_, ffi::AnyBuffer a,
ffi::AnyBuffer b, const T *alpha, const T *beta,
ffi::Result<ffi::AnyBuffer> c_out) {
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
SplitBatch2D(b.dimensions()));
int a_size = side_ ? cols : rows; // cols if right, rows if left

FFI_RETURN_IF_ERROR(
CheckShape(a.dimensions(), {batch, a_size, a_size}, "a", "symm"));
// C should have same shape as B
FFI_RETURN_IF_ERROR(
CheckShape(c_out->dimensions(), {batch, rows, cols}, "c_out", "symm"));

FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));

// We flip uplo here because A is passed in row-major format.
// Row-major A is equivalent to A^T in column-major, and since A is
// symmetric, this means we need to swap upper/lower triangular.
cublasFillMode_t uplo =
uplo_ ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
// We can swap side since (A*B)^T = B^T*A, where B^T is also the column-major
// interpretation of B
cublasSideMode_t side = side_ ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;

const T *a_data = static_cast<const T *>(a.untyped_data());
const T *b_data = static_cast<const T *>(b.untyped_data());
T *c_out_data = static_cast<T *>(c_out->untyped_data());

FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
// lda is the leading dimension of a, etc.
int lda = side == CUBLAS_SIDE_LEFT ? n : m;
int ldb = lda;
int ldc = lda;
for (int i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR(blas::Symm<T>(handle.get(), side, uplo, m, n, alpha,
a_data, lda, b_data, ldb, beta,
c_out_data, ldc));
a_data += lda * lda;
b_data += m * n;
c_out_data += m * n;
}
return ffi::Error::Success();
}

template <typename T>
ffi::Error SymmImpl(CUstream stream, bool side_, bool uplo_,
ffi::AnyBuffer a, ffi::AnyBuffer b, ffi::AnyBuffer c_in,
const T *alpha, const T *beta,
ffi::Result<ffi::AnyBuffer> c_out) {
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
SplitBatch2D(b.dimensions()));
int a_size = side_ ? cols : rows;

FFI_RETURN_IF_ERROR(
CheckShape(a.dimensions(), {batch, a_size, a_size}, "a", "symm"));
FFI_RETURN_IF_ERROR(
CheckShape(c_out->dimensions(), {batch, rows, cols}, "c_out", "symm"));

T *c_data = static_cast<T *>(c_in.untyped_data());
T *c_out_data = static_cast<T *>(c_out->untyped_data());

if (c_data != c_out_data) {
cudaError_t err = cudaMemcpyAsync(c_out_data, c_data, c_in.size_bytes(),
cudaMemcpyDeviceToDevice, stream);
if (err != cudaSuccess) {
return ffi::Error::InvalidArgument(absl::StrFormat(
"cudaMemcpyAsync failed: %s", cudaGetErrorString(err)));
}
}
return SymmImpl<T>(stream, side_, uplo_, a, b, alpha, beta, c_out);
}

template <typename T>
ffi::Error
SymmImpl(CUstream stream, bool side, bool uplo, bool use_alpha_attribute,
double alpha_real, double alpha_imag, bool use_beta_attribute,
double beta_real, double beta_imag, ffi::AnyBuffer a, ffi::AnyBuffer b,
ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_, ffi::AnyBuffer beta_,
ffi::Result<ffi::AnyBuffer> c_out) {
T host_alpha, host_beta;
FFI_RETURN_IF_ERROR(GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
alpha_imag, alpha_, &host_alpha));
FFI_RETURN_IF_ERROR(GetHostScalar<T>(stream, use_beta_attribute, beta_real,
beta_imag, beta_, &host_beta));
return SymmImpl<T>(stream, side, uplo, a, c_in, &host_alpha, &host_beta,
c_out);
}

template <typename T>
ffi::Error SymmImpl(CUstream stream, bool side, bool uplo,
bool use_alpha_attribute, double alpha_real,
double alpha_imag, ffi::AnyBuffer a, ffi::AnyBuffer b,
ffi::AnyBuffer alpha_, ffi::Result<ffi::AnyBuffer> c_out) {
T host_alpha, host_beta;
FFI_RETURN_IF_ERROR(GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
alpha_imag, alpha_, &host_alpha));
FFI_RETURN_IF_ERROR(GetHostScalar<T>(0.0, 0.0, &host_beta));
return SymmImpl<T>(stream, side, uplo, a, b, &host_alpha, &host_beta,
c_out);
}

ffi::Error SymmDispatch(CUstream stream, bool side, bool uplo,
bool use_alpha_attribute, double alpha_real,
double alpha_imag, bool use_beta_attribute,
double beta_real, double beta_imag, ffi::AnyBuffer a,
ffi::AnyBuffer b, ffi::AnyBuffer c_in,
ffi::AnyBuffer alpha_, ffi::AnyBuffer beta_,
ffi::Result<ffi::AnyBuffer> c_out) {
auto dataType = c_in.element_type();
SOLVER_BLAS_DISPATCH_IMPL(SymmImpl, stream, side, uplo,
use_alpha_attribute, alpha_real, alpha_imag,
use_beta_attribute, beta_real, beta_imag, a, b, c_in,
alpha_, beta_, c_out);
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in symm", absl::FormatStreamed(dataType)));
}

ffi::Error SymmNoCDispatch(CUstream stream, bool side, bool uplo,
bool use_alpha_attribute, double alpha_real,
double alpha_imag, ffi::AnyBuffer a, ffi::AnyBuffer b,
ffi::AnyBuffer alpha_,
ffi::Result<ffi::AnyBuffer> c_out) {
auto dataType = a.element_type();
SOLVER_BLAS_DISPATCH_IMPL(SymmImpl, stream, side, uplo,
use_alpha_attribute, alpha_real, alpha_imag, a, b,
alpha_, c_out);
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in symm", absl::FormatStreamed(dataType)));
}

XLA_FFI_DEFINE_HANDLER(
SymmFfi, SymmDispatch,
xla::ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<CUstream>>()
.Attr<bool>("side") // side
.Attr<bool>("uplo") // uplo
.Attr<bool>("use_alpha_attribute") // use_alpha_attribute
.Attr<double>("alpha_real") // alpha_real
.Attr<double>("alpha_imag") // alpha_imag
.Attr<bool>("use_beta_attribute") // use_beta_attribute
.Attr<double>("beta_real") // beta_real
.Attr<double>("beta_imag") // beta_imag
.Arg<ffi::AnyBuffer>() // a
.Arg<ffi::AnyBuffer>() // b
.Arg<ffi::AnyBuffer>() // c_in
.Arg<ffi::AnyBuffer>() // alpha
.Arg<ffi::AnyBuffer>() // beta
.Ret<ffi::AnyBuffer>() // c_out
);

XLA_FFI_DEFINE_HANDLER(
SymmNoCFfi, SymmNoCDispatch,
xla::ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<CUstream>>()
.Attr<bool>("side") // side
.Attr<bool>("uplo") // uplo
.Attr<bool>("use_alpha_attribute") // use_alpha_attribute
.Attr<double>("alpha_real") // alpha_real
.Attr<double>("alpha_imag") // alpha_imag
.Arg<ffi::AnyBuffer>() // a
.Arg<ffi::AnyBuffer>() // b
.Arg<ffi::AnyBuffer>() // alpha
.Ret<ffi::AnyBuffer>() // c_out
);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a registration to the functions in registerReactantXLACUDAFFI

void registerReactantXLACUDAFFI() {
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "reactant_cublas_syrk_ffi",
"CUDA", SyrkFfi);
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(),
"reactant_cublas_syrk_no_c_ffi", "CUDA", SyrkNoCFfi);
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "reactant_cublas_symm_ffi",
"CUDA", SymmFfi);
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(),
"reactant_cublas_symm_no_c_ffi", "CUDA", SymmNoCFfi);
}

#undef SOLVER_BLAS_DISPATCH_IMPL
Expand Down
Loading