@@ -915,7 +915,8 @@ ffi::Error GesvdjImpl(int64_t batch, int64_t rows, int64_t cols,
915915
916916 auto a_data = static_cast <T*>(a.untyped_data ());
917917 auto out_data = static_cast <T*>(out->untyped_data ());
918- auto s_data = static_cast <typename solver::RealType<T>::value*>(s->untyped_data ());
918+ auto s_data =
919+ static_cast <typename solver::RealType<T>::value*>(s->untyped_data ());
919920 auto u_data = static_cast <T*>(u->untyped_data ());
920921 auto v_data = static_cast <T*>(v->untyped_data ());
921922 auto info_data = info->typed_data ();
@@ -1014,6 +1015,101 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch,
10141015
10151016#endif // JAX_GPU_CUDA
10161017
1018+ // Symmetric tridiagonal reduction: sytrd
1019+
// Runs solver::Sytrd<T> once per matrix of a batch.
//
// The contents of `a` are first copied into `out` (skipped when both buffers
// share the same pointer); solver::Sytrd<T> is then called on each matrix stored
// in `out`. Between calls the output pointers advance by n * n (`out`), n (`d`),
// n - 1 (`e`, `tau`) and 1 (`info`) elements.
//
// Parameters:
//   batch    number of matrices to process.
//   size     matrix dimension; converted to `int` through MaybeCastNoOverflow.
//   stream   GPU stream used for the copy and for borrowing the solver handle.
//   scratch  allocator that provides the solver workspace.
//   lower    selects GPUSOLVER_FILL_MODE_LOWER when true, otherwise
//            GPUSOLVER_FILL_MODE_UPPER.
//   a        input buffer, read as elements of type T.
//   out, d, e, tau, info   result buffers written by the solver; `d` and `e`
//            are accessed as solver::RealType<T>::value, `out` and `tau` as T.
//
// Returns ffi::Error::Success() once every batch element has been submitted;
// the FFI_* / JAX_FFI_* macros are expected to return early on failure
// (NOTE(review): macro definitions are outside this hunk - confirm).
1020+ template <typename T>
1021+ ffi::Error SytrdImpl (int64_t batch, int64_t size, gpuStream_t stream,
1022+ ffi::ScratchAllocator& scratch, bool lower,
1023+ ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
1024+ ffi::Result<ffi::AnyBuffer> d,
1025+ ffi::Result<ffi::AnyBuffer> e,
1026+ ffi::Result<ffi::AnyBuffer> tau,
1027+ ffi::Result<ffi::Buffer<ffi::S32>> info) {
1028+ FFI_ASSIGN_OR_RETURN (auto n, MaybeCastNoOverflow<int >(size));
1029+ FFI_ASSIGN_OR_RETURN (auto handle, SolverHandlePool::Borrow (stream));
1030+
1031+ gpusolverFillMode_t uplo =
1032+ lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
      // Query the workspace size once and reuse the same workspace for every
      // batch element.
1033+ FFI_ASSIGN_OR_RETURN (int lwork,
1034+ solver::SytrdBufferSize<T>(handle.get (), uplo, n));
1035+ FFI_ASSIGN_OR_RETURN (auto workspace,
1036+ AllocateWorkspace<T>(scratch, lwork, " sytrd" ));
1037+
1038+ auto * a_data = static_cast <T*>(a.untyped_data ());
1039+ auto * out_data = static_cast <T*>(out->untyped_data ());
1040+ auto * d_data =
1041+ static_cast <typename solver::RealType<T>::value*>(d->untyped_data ());
1042+ auto * e_data =
1043+ static_cast <typename solver::RealType<T>::value*>(e->untyped_data ());
1044+ auto * tau_data = static_cast <T*>(tau->untyped_data ());
1045+ auto * info_data = info->typed_data ();
      // The solver works on `out`, so seed it with the input unless the two
      // buffers already share storage.
1046+ if (a_data != out_data) {
1047+ JAX_FFI_RETURN_IF_GPU_ERROR (gpuMemcpyAsync (
1048+ out_data, a_data, a.size_bytes (), gpuMemcpyDeviceToDevice, stream));
1049+ }
1050+
      // Compute the per-matrix stride in 64 bits: `n` is an int, and n * n
      // evaluated in int overflows (undefined behaviour) once n >= 46341.
1051+ const int64_t out_step = static_cast<int64_t>(n) * n;
1052+ for (int64_t i = 0 ; i < batch; ++i) {
1053+ FFI_RETURN_IF_ERROR_STATUS (solver::Sytrd<T>(handle.get (), uplo, n, out_data,
1054+ d_data, e_data, tau_data,
1055+ workspace, lwork, info_data));
1056+ out_data += out_step;
1057+ d_data += n;
1058+ e_data += n - 1 ;
1059+ tau_data += n - 1 ;
1060+ ++info_data;
1061+ }
1062+ return ffi::Error::Success ();
1063+ }
1064+
// FFI entry point for sytrd: validates the buffers, then dispatches to
// SytrdImpl through SOLVER_DISPATCH_IMPL.
//
// Checks performed, in order:
//   1. `out` and `tau` have the element type of `a`; `d` and `e` have
//      ffi::ToReal(a's element type).
//   2. `a` splits into (batch, rows, cols) via SplitBatch2D and rows == cols.
//   3. Shapes: out = {batch, rows, cols}, d = {batch, cols},
//      e = tau = {batch, cols - 1}, info = batch.
//
// Returns ffi::Error::InvalidArgument when a type or squareness check fails, or
// when control reaches the end because the dispatch macro did not return
// (reported as an unsupported dtype).
// NOTE(review): SOLVER_DISPATCH_IMPL is defined outside this hunk; presumably it
// returns SytrdImpl<T>(...) for each supported dtype - confirm.
1065+ ffi::Error SytrdDispatch (gpuStream_t stream, ffi::ScratchAllocator scratch,
1066+ bool lower, ffi::AnyBuffer a,
1067+ ffi::Result<ffi::AnyBuffer> out,
1068+ ffi::Result<ffi::AnyBuffer> d,
1069+ ffi::Result<ffi::AnyBuffer> e,
1070+ ffi::Result<ffi::AnyBuffer> tau,
1071+ ffi::Result<ffi::Buffer<ffi::S32>> info) {
1072+ auto dataType = a.element_type ();
      // NOTE(review): `d` and `e` are compared against the real counterpart of
      // dataType, so the "same element type" wording in the message below is
      // imprecise for them.
1073+ if (out->element_type () != dataType ||
1074+ d->element_type () != ffi::ToReal (dataType) ||
1075+ e->element_type () != ffi::ToReal (dataType) ||
1076+ tau->element_type () != dataType) {
1077+ return ffi::Error::InvalidArgument (
1078+ " The inputs and outputs to sytrd must have the same element type" );
1079+ }
1080+ FFI_ASSIGN_OR_RETURN ((auto [batch, rows, cols]),
1081+ SplitBatch2D (a.dimensions ()));
1082+ if (rows != cols) {
1083+ return ffi::Error::InvalidArgument (
1084+ " The input matrix to sytrd must be square" );
1085+ }
1086+ FFI_RETURN_IF_ERROR (
1087+ CheckShape (out->dimensions (), {batch, rows, cols}, " out" , " sytrd" ));
1088+ FFI_RETURN_IF_ERROR (CheckShape (d->dimensions (), {batch, cols}, " d" , " sytrd" ));
      // NOTE(review): for cols == 0 the expected trailing dimension of `e` and
      // `tau` becomes -1; confirm how empty matrices are handled by callers.
1089+ FFI_RETURN_IF_ERROR (
1090+ CheckShape (e->dimensions (), {batch, cols - 1 }, " e" , " sytrd" ));
1091+ FFI_RETURN_IF_ERROR (
1092+ CheckShape (tau->dimensions (), {batch, cols - 1 }, " tau" , " sytrd" ));
1093+ FFI_RETURN_IF_ERROR (CheckShape (info->dimensions (), batch, " info" , " sytrd" ));
      // rows == cols at this point, so `rows` is passed as the matrix size.
1094+ SOLVER_DISPATCH_IMPL (SytrdImpl, batch, rows, stream, scratch, lower, a, out,
1095+ d, e, tau, info);
      // Reached only if the dispatch macro above did not return.
1096+ return ffi::Error::InvalidArgument (absl::StrFormat (
1097+ " Unsupported dtype %s in sytrd" , absl::FormatStreamed (dataType)));
1098+ }
1099+
// Defines the FFI handler symbol SytrdFfi, bound to SytrdDispatch.
// The binding lists, in the same order as SytrdDispatch's parameters: the
// platform stream and scratch allocator contexts, one bool attribute, one
// input buffer (a) and five result buffers (out, d, e, tau, info).
1100+ XLA_FFI_DEFINE_HANDLER_SYMBOL (SytrdFfi, SytrdDispatch,
1101+ ffi::Ffi::Bind ()
1102+ .Ctx<ffi::PlatformStream<gpuStream_t>>()
1103+ .Ctx<ffi::ScratchAllocator>()
1104+ .Attr<bool>(" lower" )
1105+ .Arg<ffi::AnyBuffer>() // a
1106+ .Ret<ffi::AnyBuffer>() // out
1107+ .Ret<ffi::AnyBuffer>() // d
1108+ .Ret<ffi::AnyBuffer>() // e
1109+ .Ret<ffi::AnyBuffer>() // tau
1110+ .Ret<ffi::Buffer<ffi::S32>>() // info
1111+ );
1112+
10171113#undef SOLVER_DISPATCH_IMPL
10181114#undef SOLVER_BLAS_DISPATCH_IMPL
10191115
0 commit comments