Skip to content

Commit 32df37e

Browse files
dfmGoogle-ML-Automation
authored andcommitted
Port symmetric tridiagonal reduction GPU kernel to FFI.
PiperOrigin-RevId: 704382200
1 parent 66b9005 commit 32df37e

File tree

6 files changed

+152
-7
lines changed

6 files changed

+152
-7
lines changed

jaxlib/gpu/gpu_kernels.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,14 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA");
6060
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA",
6161
SyevdFfi);
6262
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA");
63+
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_sytrd_ffi", "CUDA",
64+
SytrdFfi);
6365
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA");
66+
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvd_ffi", "CUDA",
67+
GesvdFfi);
6468
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA");
69+
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvdj_ffi", "CUDA",
70+
GesvdjFfi);
6571

6672
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_cholesky_update_ffi", "CUDA",
6773
CholeskyUpdateFfi);

jaxlib/gpu/solver.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ nb::dict Registrations() {
482482
dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi);
483483
dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi);
484484
dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi);
485+
dict[JAX_GPU_PREFIX "solver_sytrd_ffi"] = EncapsulateFfiHandler(SytrdFfi);
485486

486487
#ifdef JAX_GPU_CUDA
487488
dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi);

jaxlib/gpu/solver_interface.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,34 @@ JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched);
317317

318318
#endif // JAX_GPU_CUDA
319319

320+
// Symmetric tridiagonal reduction: sytrd
321+
322+
#define JAX_GPU_DEFINE_SYTRD(Type, Name) \
323+
template <> \
324+
absl::StatusOr<int> SytrdBufferSize<Type>(gpusolverDnHandle_t handle, \
325+
gpusolverFillMode_t uplo, int n) { \
326+
int lwork; \
327+
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \
328+
handle, uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, \
329+
/*E=*/nullptr, /*tau=*/nullptr, &lwork))); \
330+
return lwork; \
331+
} \
332+
\
333+
template <> \
334+
absl::Status Sytrd<Type>(gpusolverDnHandle_t handle, \
335+
gpusolverFillMode_t uplo, int n, Type *a, \
336+
RealType<Type>::value *d, RealType<Type>::value *e, \
337+
Type *tau, Type *workspace, int lwork, int *info) { \
338+
return JAX_AS_STATUS( \
339+
Name(handle, uplo, n, a, n, d, e, tau, workspace, lwork, info)); \
340+
}
341+
342+
JAX_GPU_DEFINE_SYTRD(float, gpusolverDnSsytrd);
343+
JAX_GPU_DEFINE_SYTRD(double, gpusolverDnDsytrd);
344+
JAX_GPU_DEFINE_SYTRD(gpuComplex, gpusolverDnChetrd);
345+
JAX_GPU_DEFINE_SYTRD(gpuDoubleComplex, gpusolverDnZhetrd);
346+
#undef JAX_GPU_DEFINE_SYTRD
347+
320348
} // namespace solver
321349
} // namespace JAX_GPU_NAMESPACE
322350
} // namespace jax

jaxlib/gpu/solver_interface.h

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBufferSize);
188188

189189
#define JAX_GPU_SOLVER_Gesvdj_ARGS(Type, Real) \
190190
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \
191-
Type *a, Real *s, Type *u, Type *v, Type *workspace, \
192-
int lwork, int *info, gesvdjInfo_t params
191+
Type *a, Real *s, Type *u, Type *v, Type *workspace, int lwork, \
192+
int *info, gesvdjInfo_t params
193193
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj);
194194
#undef JAX_GPU_SOLVER_Gesvdj_ARGS
195195

@@ -199,15 +199,28 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj);
199199
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBatchedBufferSize);
200200
#undef JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS
201201

202-
#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real) \
203-
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, Type *a, \
204-
Real *s, Type *u, Type *v, Type *workspace, int lwork, \
205-
int *info, gpuGesvdjInfo_t params, int batch
202+
#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real) \
203+
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, Type *a, \
204+
Real *s, Type *u, Type *v, Type *workspace, int lwork, int *info, \
205+
gpuGesvdjInfo_t params, int batch
206206
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched);
207207
#undef JAX_GPU_SOLVER_GesvdjBatched_ARGS
208208

209209
#endif // JAX_GPU_CUDA
210210

211+
// Symmetric tridiagonal reduction: sytrd
212+
213+
#define JAX_GPU_SOLVER_SytrdBufferSize_ARGS(Type, ...) \
214+
gpusolverDnHandle_t handle, gpublasFillMode_t uplo, int n
215+
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, SytrdBufferSize);
216+
#undef JAX_GPU_SOLVER_SytrdBufferSize_ARGS
217+
218+
#define JAX_GPU_SOLVER_Sytrd_ARGS(Type, Real) \
219+
gpusolverDnHandle_t handle, gpublasFillMode_t uplo, int n, Type *a, Real *d, \
220+
Real *e, Type *tau, Type *workspace, int lwork, int *info
221+
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Sytrd);
222+
#undef JAX_GPU_SOLVER_Sytrd_ARGS
223+
211224
#undef JAX_GPU_SOLVER_EXPAND_DEFINITION
212225

213226
} // namespace solver

jaxlib/gpu/solver_kernels_ffi.cc

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,8 @@ ffi::Error GesvdjImpl(int64_t batch, int64_t rows, int64_t cols,
915915

916916
auto a_data = static_cast<T*>(a.untyped_data());
917917
auto out_data = static_cast<T*>(out->untyped_data());
918-
auto s_data = static_cast<typename solver::RealType<T>::value*>(s->untyped_data());
918+
auto s_data =
919+
static_cast<typename solver::RealType<T>::value*>(s->untyped_data());
919920
auto u_data = static_cast<T*>(u->untyped_data());
920921
auto v_data = static_cast<T*>(v->untyped_data());
921922
auto info_data = info->typed_data();
@@ -1014,6 +1015,101 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch,
10141015

10151016
#endif // JAX_GPU_CUDA
10161017

1018+
// Symmetric tridiagonal reduction: sytrd
1019+
1020+
template <typename T>
1021+
ffi::Error SytrdImpl(int64_t batch, int64_t size, gpuStream_t stream,
1022+
ffi::ScratchAllocator& scratch, bool lower,
1023+
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
1024+
ffi::Result<ffi::AnyBuffer> d,
1025+
ffi::Result<ffi::AnyBuffer> e,
1026+
ffi::Result<ffi::AnyBuffer> tau,
1027+
ffi::Result<ffi::Buffer<ffi::S32>> info) {
1028+
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(size));
1029+
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
1030+
1031+
gpusolverFillMode_t uplo =
1032+
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
1033+
FFI_ASSIGN_OR_RETURN(int lwork,
1034+
solver::SytrdBufferSize<T>(handle.get(), uplo, n));
1035+
FFI_ASSIGN_OR_RETURN(auto workspace,
1036+
AllocateWorkspace<T>(scratch, lwork, "sytrd"));
1037+
1038+
auto* a_data = static_cast<T*>(a.untyped_data());
1039+
auto* out_data = static_cast<T*>(out->untyped_data());
1040+
auto* d_data =
1041+
static_cast<typename solver::RealType<T>::value*>(d->untyped_data());
1042+
auto* e_data =
1043+
static_cast<typename solver::RealType<T>::value*>(e->untyped_data());
1044+
auto* tau_data = static_cast<T*>(tau->untyped_data());
1045+
auto* info_data = info->typed_data();
1046+
if (a_data != out_data) {
1047+
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
1048+
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
1049+
}
1050+
1051+
int out_step = n * n;
1052+
for (int64_t i = 0; i < batch; ++i) {
1053+
FFI_RETURN_IF_ERROR_STATUS(solver::Sytrd<T>(handle.get(), uplo, n, out_data,
1054+
d_data, e_data, tau_data,
1055+
workspace, lwork, info_data));
1056+
out_data += out_step;
1057+
d_data += n;
1058+
e_data += n - 1;
1059+
tau_data += n - 1;
1060+
++info_data;
1061+
}
1062+
return ffi::Error::Success();
1063+
}
1064+
1065+
ffi::Error SytrdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
1066+
bool lower, ffi::AnyBuffer a,
1067+
ffi::Result<ffi::AnyBuffer> out,
1068+
ffi::Result<ffi::AnyBuffer> d,
1069+
ffi::Result<ffi::AnyBuffer> e,
1070+
ffi::Result<ffi::AnyBuffer> tau,
1071+
ffi::Result<ffi::Buffer<ffi::S32>> info) {
1072+
auto dataType = a.element_type();
1073+
if (out->element_type() != dataType ||
1074+
d->element_type() != ffi::ToReal(dataType) ||
1075+
e->element_type() != ffi::ToReal(dataType) ||
1076+
tau->element_type() != dataType) {
1077+
return ffi::Error::InvalidArgument(
1078+
"The inputs and outputs to sytrd must have the same element type");
1079+
}
1080+
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
1081+
SplitBatch2D(a.dimensions()));
1082+
if (rows != cols) {
1083+
return ffi::Error::InvalidArgument(
1084+
"The input matrix to sytrd must be square");
1085+
}
1086+
FFI_RETURN_IF_ERROR(
1087+
CheckShape(out->dimensions(), {batch, rows, cols}, "out", "sytrd"));
1088+
FFI_RETURN_IF_ERROR(CheckShape(d->dimensions(), {batch, cols}, "d", "sytrd"));
1089+
FFI_RETURN_IF_ERROR(
1090+
CheckShape(e->dimensions(), {batch, cols - 1}, "e", "sytrd"));
1091+
FFI_RETURN_IF_ERROR(
1092+
CheckShape(tau->dimensions(), {batch, cols - 1}, "tau", "sytrd"));
1093+
FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "sytrd"));
1094+
SOLVER_DISPATCH_IMPL(SytrdImpl, batch, rows, stream, scratch, lower, a, out,
1095+
d, e, tau, info);
1096+
return ffi::Error::InvalidArgument(absl::StrFormat(
1097+
"Unsupported dtype %s in sytrd", absl::FormatStreamed(dataType)));
1098+
}
1099+
1100+
XLA_FFI_DEFINE_HANDLER_SYMBOL(SytrdFfi, SytrdDispatch,
1101+
ffi::Ffi::Bind()
1102+
.Ctx<ffi::PlatformStream<gpuStream_t>>()
1103+
.Ctx<ffi::ScratchAllocator>()
1104+
.Attr<bool>("lower")
1105+
.Arg<ffi::AnyBuffer>() // a
1106+
.Ret<ffi::AnyBuffer>() // out
1107+
.Ret<ffi::AnyBuffer>() // d
1108+
.Ret<ffi::AnyBuffer>() // e
1109+
.Ret<ffi::AnyBuffer>() // tau
1110+
.Ret<ffi::Buffer<ffi::S32>>() // info
1111+
);
1112+
10171113
#undef SOLVER_DISPATCH_IMPL
10181114
#undef SOLVER_BLAS_DISPATCH_IMPL
10191115

jaxlib/gpu/solver_kernels_ffi.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi);
3636
XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi);
3737
XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi);
3838
XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdFfi);
39+
XLA_FFI_DECLARE_HANDLER_SYMBOL(SytrdFfi);
3940

4041
#ifdef JAX_GPU_CUDA
4142
XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi);

0 commit comments

Comments
 (0)