diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py
index 7e4bea683926..04e70518c0fa 100644
--- a/jax/_src/export/_export.py
+++ b/jax/_src/export/_export.py
@@ -1070,6 +1070,8 @@ def _check_lowering(lowering) -> None:
     # qr on GPU
     "cusolver_geqrf_ffi", "cusolver_orgqr_ffi",
     "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi",
+    # cholesky on GPU
+    "cusolver_potrf_ffi", "hipsolver_potrf_ffi",
     # eigh on GPU
     "cusolver_syevd_ffi", "hipsolver_syevd_ffi",
     # svd on GPU
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 5b3f03fd3a16..ef39cf51f44a 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -45,8 +45,8 @@
 from jax._src.lib import gpu_linalg
 from jax._src.lib import gpu_solver
 from jax._src.lib import gpu_sparse
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib import lapack
+from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
@@ -865,11 +865,31 @@ def _cholesky_cpu_lowering(ctx, operand):
   return [_replace_not_ok_with_nan(ctx, batch_dims, ok, result, out_aval)]
 
 
+def _cholesky_gpu_lowering(ctx, operand, *, target_name_prefix):
+  # TODO(phawkins): remove forward compat path after Nov 10, 2025.
+  if ctx.is_forward_compat():
+    return _cholesky_lowering(ctx, operand)
+  operand_aval, = ctx.avals_in
+  out_aval, = ctx.avals_out
+  batch_dims = operand_aval.shape[:-2]
+  info_aval = ShapedArray(batch_dims, np.int32)
+  rule = _linalg_ffi_lowering(f"{target_name_prefix}solver_potrf_ffi",
+                              avals_out=[operand_aval, info_aval],
+                              operand_output_aliases={0: 0})
+  result, info = rule(ctx, operand, lower=True)
+  ok = mlir.compare_hlo(info, mlir.full_like_aval(ctx, 0, info_aval), "EQ",
+                        "SIGNED")
+  return [_replace_not_ok_with_nan(ctx, batch_dims, ok, result, out_aval)]
+
+
 cholesky_p = standard_linalg_primitive(
     (_float | _complex,), (2,), _cholesky_shape_rule, "cholesky")
 ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
 mlir.register_lowering(cholesky_p, _cholesky_lowering)
 mlir.register_lowering(cholesky_p, _cholesky_cpu_lowering, platform="cpu")
+if jaxlib_version >= (0, 8, 0):
+  register_cpu_gpu_lowering(cholesky_p, _cholesky_gpu_lowering,
+                            supported_platforms=("cuda", "rocm"))
 
 
 # Cholesky update
diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc
index 8501c959f727..7152335d9cce 100644
--- a/jaxlib/gpu/gpu_kernels.cc
+++ b/jaxlib/gpu/gpu_kernels.cc
@@ -44,6 +44,8 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA",
                          CsrlsvqrFfi);
 XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA",
                          OrgqrFfi);
+XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_potrf_ffi", "CUDA",
+                         PotrfFfi);
 XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA",
                          SyevdFfi);
 XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_sytrd_ffi", "CUDA",
diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc
index 53ee0aa1f0e1..6fcd3155c2b5 100644
--- a/jaxlib/gpu/solver.cc
+++ b/jaxlib/gpu/solver.cc
@@ -30,6 +30,7 @@ nb::dict Registrations() {
   dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi);
   dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi);
   dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi);
+  dict[JAX_GPU_PREFIX "solver_potrf_ffi"] = EncapsulateFfiHandler(PotrfFfi);
   dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi);
   dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi);
   dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi);
diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc
index f439413215b2..6ef8652ade37 100644
--- a/jaxlib/gpu/solver_interface.cc
+++ b/jaxlib/gpu/solver_interface.cc
@@ -131,6 +131,46 @@ JAX_GPU_DEFINE_ORGQR(gpuComplex, gpusolverDnCungqr);
 JAX_GPU_DEFINE_ORGQR(gpuDoubleComplex, gpusolverDnZungqr);
 #undef JAX_GPU_DEFINE_ORGQR
 
+// Cholesky decomposition: potrf
+
+#define JAX_GPU_DEFINE_POTRF(Type, Name)                                    \
+  template <>                                                               \
+  absl::StatusOr<int> PotrfBufferSize<Type>(gpusolverDnHandle_t handle,     \
+                                            gpusolverFillMode_t uplo, int n) { \
+    int lwork;                                                              \
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                      \
+        Name##_bufferSize(handle, uplo, n, /*A=*/nullptr, n, &lwork)));     \
+    return lwork;                                                           \
+  }                                                                         \
+                                                                            \
+  template <>                                                               \
+  absl::Status Potrf<Type>(gpusolverDnHandle_t handle,                      \
+                           gpusolverFillMode_t uplo, int n, Type *a,        \
+                           Type *workspace, int lwork, int *info) {         \
+    return JAX_AS_STATUS(                                                   \
+        Name(handle, uplo, n, a, n, workspace, lwork, info));               \
+  }
+
+JAX_GPU_DEFINE_POTRF(float, gpusolverDnSpotrf);
+JAX_GPU_DEFINE_POTRF(double, gpusolverDnDpotrf);
+JAX_GPU_DEFINE_POTRF(gpuComplex, gpusolverDnCpotrf);
+JAX_GPU_DEFINE_POTRF(gpuDoubleComplex, gpusolverDnZpotrf);
+#undef JAX_GPU_DEFINE_POTRF
+
+#define JAX_GPU_DEFINE_POTRF_BATCHED(Type, Name)                            \
+  template <>                                                               \
+  absl::Status PotrfBatched<Type>(gpusolverDnHandle_t handle,               \
+                                  gpusolverFillMode_t uplo, int n, Type **a, \
+                                  int lda, int *info, int batch) {          \
+    return JAX_AS_STATUS(Name(handle, uplo, n, a, lda, info, batch));       \
+  }
+
+JAX_GPU_DEFINE_POTRF_BATCHED(float, gpusolverDnSpotrfBatched);
+JAX_GPU_DEFINE_POTRF_BATCHED(double, gpusolverDnDpotrfBatched);
+JAX_GPU_DEFINE_POTRF_BATCHED(gpuComplex, gpusolverDnCpotrfBatched);
+JAX_GPU_DEFINE_POTRF_BATCHED(gpuDoubleComplex, gpusolverDnZpotrfBatched);
+#undef JAX_GPU_DEFINE_POTRF_BATCHED
+
 // Symmetric (Hermitian) eigendecomposition:
 // * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
 // * QR algorithm: syevd/heevd
diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h
index fa11f3d0e752..e0a44b87527a 100644
--- a/jaxlib/gpu/solver_interface.h
+++ b/jaxlib/gpu/solver_interface.h
@@ -117,6 +117,25 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, OrgqrBufferSize);
 JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Orgqr);
 #undef JAX_GPU_SOLVER_Orgqr_ARGS
 
+// Cholesky decomposition: potrf
+
+#define JAX_GPU_SOLVER_PotrfBufferSize_ARGS(Type, ...) \
+  gpusolverDnHandle_t handle, gpusolverFillMode_t uplo, int n
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, PotrfBufferSize);
+#undef JAX_GPU_SOLVER_PotrfBufferSize_ARGS
+
+#define JAX_GPU_SOLVER_Potrf_ARGS(Type, ...)                             \
+  gpusolverDnHandle_t handle, gpusolverFillMode_t uplo, int n, Type *a,  \
+      Type *workspace, int lwork, int *info
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Potrf);
+#undef JAX_GPU_SOLVER_Potrf_ARGS
+
+#define JAX_GPU_SOLVER_PotrfBatched_ARGS(Type, ...)                      \
+  gpusolverDnHandle_t handle, gpusolverFillMode_t uplo, int n, Type **a, \
+      int lda, int *info, int batch
+JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, PotrfBatched);
+#undef JAX_GPU_SOLVER_PotrfBatched_ARGS
+
 // Symmetric (Hermitian) eigendecomposition:
 // * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
 // * QR algorithm: syevd/heevd
diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc
index 49d6aff79f76..4e39545053ac 100644
--- a/jaxlib/gpu/solver_kernels_ffi.cc
+++ b/jaxlib/gpu/solver_kernels_ffi.cc
@@ -389,6 +389,113 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
                                   .Ret<ffi::AnyBuffer>()  // out
 );
 
+// Cholesky decomposition: potrf
+
+template <typename T>
+ffi::Error PotrfImpl(int64_t batch, int64_t size, gpuStream_t stream,
+                     ffi::ScratchAllocator& scratch, bool lower,
+                     ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
+                     ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(size));
+  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
+
+  gpusolverFillMode_t uplo =
+      lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
+
+  FFI_ASSIGN_OR_RETURN(int lwork,
+                       solver::PotrfBufferSize<T>(handle.get(), uplo, n));
+  FFI_ASSIGN_OR_RETURN(auto workspace,
+                       AllocateWorkspace<T>(scratch, lwork, "potrf"));
+
+  auto a_data = static_cast<T*>(a.untyped_data());
+  auto out_data = static_cast<T*>(out->untyped_data());
+  auto info_data = info->typed_data();
+  if (a_data != out_data) {
+    JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
+  }
+
+  int out_step = n * n;
+  for (auto i = 0; i < batch; ++i) {
+    FFI_RETURN_IF_ERROR_STATUS(solver::Potrf<T>(handle.get(), uplo, n,
+                                                out_data, workspace, lwork,
+                                                info_data));
+    out_data += out_step;
+    ++info_data;
+  }
+  return ffi::Error::Success();
+}
+
+template <typename T>
+ffi::Error PotrfBatchedImpl(int64_t batch, int64_t size, gpuStream_t stream,
+                            ffi::ScratchAllocator& scratch, bool lower,
+                            ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
+                            ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(size));
+  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
+  FFI_ASSIGN_OR_RETURN(auto batch_ptrs,
+                       AllocateWorkspace<T*>(scratch, batch, "batched potrf"));
+
+  gpusolverFillMode_t uplo =
+      lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
+
+  auto a_data = a.untyped_data();
+  auto out_data = out->untyped_data();
+  auto info_data = info->typed_data();
+  if (a_data != out_data) {
+    JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
+  }
+
+  MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch,
+                         sizeof(T) * n * n);
+  JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());
+
+  FFI_RETURN_IF_ERROR_STATUS(solver::PotrfBatched<T>(
+      handle.get(), uplo, n, batch_ptrs, n, info_data, batch));
+
+  return ffi::Error::Success();
+}
+
+ffi::Error PotrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
+                         bool lower, ffi::AnyBuffer a,
+                         ffi::Result<ffi::AnyBuffer> out,
+                         ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+  auto dataType = a.element_type();
+  if (dataType != out->element_type()) {
+    return ffi::Error::InvalidArgument(
+        "The input and output to potrf must have the same element type");
+  }
+  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+                       SplitBatch2D(a.dimensions()));
+  if (rows != cols) {
+    return ffi::Error::InvalidArgument(
+        "The input matrix to potrf must be square");
+  }
+  FFI_RETURN_IF_ERROR(
+      CheckShape(out->dimensions(), {batch, rows, cols}, "out", "potrf"));
+  FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "potrf"));
+  if (batch > 1) {
+    SOLVER_DISPATCH_IMPL(PotrfBatchedImpl, batch, rows, stream, scratch, lower,
+                         a, out, info);
+  } else {
+    SOLVER_DISPATCH_IMPL(PotrfImpl, batch, rows, stream, scratch, lower, a,
+                         out, info);
+  }
+  return ffi::Error::InvalidArgument(absl::StrFormat(
+      "Unsupported dtype %s in potrf", absl::FormatStreamed(dataType)));
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(PotrfFfi, PotrfDispatch,
+                              ffi::Ffi::Bind()
+                                  .Ctx<ffi::PlatformStream<gpuStream_t>>()
+                                  .Ctx<ffi::ScratchAllocator>()
+                                  .Attr<bool>("lower")
+                                  .Arg<ffi::AnyBuffer>()  // a
+                                  .Ret<ffi::AnyBuffer>()  // out
+                                  .Ret<ffi::Buffer<ffi::DataType::S32>>()  // info
+);
+
 // Symmetric (Hermitian) eigendecomposition:
 // * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
 // * QR algorithm: syevd/heevd
diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h
index e443509ce89f..668932a7b06a 100644
--- a/jaxlib/gpu/solver_kernels_ffi.h
+++ b/jaxlib/gpu/solver_kernels_ffi.h
@@ -33,6 +33,7 @@ enum class SyevdAlgorithm : uint8_t {
 XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(PotrfFfi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdFfi);
diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h
index 5b9d71360f5e..62268a7d8e25 100644
--- a/jaxlib/gpu/vendor.h
+++ b/jaxlib/gpu/vendor.h
@@ -197,6 +197,18 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define gpusolverDnDorgqr_bufferSize cusolverDnDorgqr_bufferSize
 #define gpusolverDnCungqr_bufferSize cusolverDnCungqr_bufferSize
 #define gpusolverDnZungqr_bufferSize cusolverDnZungqr_bufferSize
+#define gpusolverDnSpotrf cusolverDnSpotrf
+#define gpusolverDnDpotrf cusolverDnDpotrf
+#define gpusolverDnCpotrf cusolverDnCpotrf
+#define gpusolverDnZpotrf cusolverDnZpotrf
+#define gpusolverDnSpotrf_bufferSize cusolverDnSpotrf_bufferSize
+#define gpusolverDnDpotrf_bufferSize cusolverDnDpotrf_bufferSize
+#define gpusolverDnCpotrf_bufferSize cusolverDnCpotrf_bufferSize
+#define gpusolverDnZpotrf_bufferSize cusolverDnZpotrf_bufferSize
+#define gpusolverDnSpotrfBatched cusolverDnSpotrfBatched
+#define gpusolverDnDpotrfBatched cusolverDnDpotrfBatched
+#define gpusolverDnCpotrfBatched cusolverDnCpotrfBatched
+#define gpusolverDnZpotrfBatched cusolverDnZpotrfBatched
 #define gpusolverDnSsyevd cusolverDnSsyevd
 #define gpusolverDnDsyevd cusolverDnDsyevd
 #define gpusolverDnCheevd cusolverDnCheevd
@@ -591,6 +603,18 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define gpusolverDnDorgqr_bufferSize hipsolverDorgqr_bufferSize
 #define gpusolverDnCungqr_bufferSize hipsolverCungqr_bufferSize
 #define gpusolverDnZungqr_bufferSize hipsolverZungqr_bufferSize
+#define gpusolverDnSpotrf hipsolverSpotrf
+#define gpusolverDnDpotrf hipsolverDpotrf
+#define gpusolverDnCpotrf hipsolverCpotrf
+#define gpusolverDnZpotrf hipsolverZpotrf
+#define gpusolverDnSpotrf_bufferSize hipsolverSpotrf_bufferSize
+#define gpusolverDnDpotrf_bufferSize hipsolverDpotrf_bufferSize
+#define gpusolverDnCpotrf_bufferSize hipsolverCpotrf_bufferSize
+#define gpusolverDnZpotrf_bufferSize hipsolverZpotrf_bufferSize
+#define gpusolverDnSpotrfBatched hipsolverDnSpotrfBatched
+#define gpusolverDnDpotrfBatched hipsolverDnDpotrfBatched
+#define gpusolverDnCpotrfBatched hipsolverDnCpotrfBatched
+#define gpusolverDnZpotrfBatched hipsolverDnZpotrfBatched
 #define gpusolverDnSsyevd hipsolverSsyevd
 #define gpusolverDnDsyevd hipsolverDsyevd
 #define gpusolverDnCheevd hipsolverCheevd
diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py
index f147fb355c8d..10ea2c5869ea 100644
--- a/tests/export_back_compat_test.py
+++ b/tests/export_back_compat_test.py
@@ -168,6 +168,7 @@ def test_custom_call_coverage(self):
       "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
       "hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", "hipsolver_syevd_ffi",
       "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi",
+      "cusolver_potrf_ffi", "hipsolver_potrf_ffi",
     })
     not_covered = targets_to_cover.difference(covered_targets)
     self.assertEmpty(not_covered,
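
A minimal sketch of how this patch is exercised end to end (not part of the diff; the input construction, variable names, and tolerance below are illustrative assumptions). On a CUDA or ROCm backend with jaxlib >= 0.8.0, `jnp.linalg.cholesky` lowers `cholesky_p` to the new `{cu,hip}solver_potrf_ffi` custom call, and a batched input takes the `PotrfBatchedImpl` path:

```python
# Illustrative only: exercises the new GPU Cholesky lowering.
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 16, 16))
spd = a @ np.swapaxes(a, -1, -2) + 16 * np.eye(16)  # batch of SPD matrices

chol = jax.jit(jnp.linalg.cholesky)
l = chol(spd)  # batch > 1, so the kernel dispatches to PotrfBatchedImpl

# L @ L^T should reconstruct the input (float32 by default, hence loose rtol).
np.testing.assert_allclose(l @ jnp.swapaxes(l, -1, -2), spd, rtol=1e-3)

# The custom call should appear in the lowered module:
# "cusolver_potrf_ffi" on CUDA, "hipsolver_potrf_ffi" on ROCm.
print(chol.lower(spd).as_text())
```

Since the FFI handler aliases the input with the output (`operand_output_aliases={0: 0}` in the lowering, and the `a_data != out_data` check before the device-to-device copy in the kernels), the factorization runs in place whenever XLA donates the input buffer.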