Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,8 @@ def _check_lowering(lowering) -> None:
# qr on GPU
"cusolver_geqrf_ffi", "cusolver_orgqr_ffi",
"hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi",
# cholesky on GPU
"cusolver_potrf_ffi", "hipsolver_potrf_ffi",
# eigh on GPU
"cusolver_syevd_ffi", "hipsolver_syevd_ffi",
# svd on GPU
Expand Down
22 changes: 21 additions & 1 deletion jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
from jax._src.lib import gpu_linalg
from jax._src.lib import gpu_solver
from jax._src.lib import gpu_sparse
from jax._src.lib import version as jaxlib_version
from jax._src.lib import lapack
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
Expand Down Expand Up @@ -865,11 +865,31 @@ def _cholesky_cpu_lowering(ctx, operand):
return [_replace_not_ok_with_nan(ctx, batch_dims, ok, result, out_aval)]


def _cholesky_gpu_lowering(ctx, operand, *, target_name_prefix):
  """Lowers cholesky on GPU to the {cu,hip}solver potrf FFI custom call.

  Positions with a nonzero potrf ``info`` status (non-positive-definite
  input) are replaced with NaN in the returned factor.
  """
  # TODO(phawkins): remove forward compat path after Nov 10, 2025.
  if ctx.is_forward_compat():
    return _cholesky_lowering(ctx, operand)
  a_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  batch_dims = a_aval.shape[:-2]
  # One int32 status per batched matrix, as produced by the FFI kernel.
  info_aval = ShapedArray(batch_dims, np.int32)
  target = f"{target_name_prefix}solver_potrf_ffi"
  lowering_rule = _linalg_ffi_lowering(
      target,
      avals_out=[a_aval, info_aval],
      operand_output_aliases={0: 0})
  result, info = lowering_rule(ctx, operand, lower=True)
  zeros = mlir.full_like_aval(ctx, 0, info_aval)
  ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
  return [_replace_not_ok_with_nan(ctx, batch_dims, ok, result, out_aval)]


# Cholesky decomposition primitive over float/complex operands with at
# least two dimensions (factorization applies to the trailing two).
cholesky_p = standard_linalg_primitive(
    (_float | _complex,), (2,), _cholesky_shape_rule, "cholesky")
ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
mlir.register_lowering(cholesky_p, _cholesky_lowering)
mlir.register_lowering(cholesky_p, _cholesky_cpu_lowering, platform="cpu")
# The GPU potrf FFI targets only exist in jaxlib >= 0.8.0; older jaxlibs
# keep the default lowering registered above.
if jaxlib_version >= (0, 8, 0):
  register_cpu_gpu_lowering(cholesky_p, _cholesky_gpu_lowering,
                            supported_platforms=("cuda", "rocm"))


# Cholesky update
Expand Down
2 changes: 2 additions & 0 deletions jaxlib/gpu/gpu_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA",
CsrlsvqrFfi);
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA",
OrgqrFfi);
// Cholesky decomposition (potrf); handler defined in solver_kernels_ffi.cc.
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_potrf_ffi", "CUDA",
                         PotrfFfi);
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA",
SyevdFfi);
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_sytrd_ffi", "CUDA",
Expand Down
1 change: 1 addition & 0 deletions jaxlib/gpu/solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ nb::dict Registrations() {
dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi);
dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi);
dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi);
// Cholesky decomposition (potrf) handler.
dict[JAX_GPU_PREFIX "solver_potrf_ffi"] = EncapsulateFfiHandler(PotrfFfi);
dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi);
dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi);
dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi);
Expand Down
40 changes: 40 additions & 0 deletions jaxlib/gpu/solver_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,46 @@ JAX_GPU_DEFINE_ORGQR(gpuComplex, gpusolverDnCungqr);
JAX_GPU_DEFINE_ORGQR(gpuDoubleComplex, gpusolverDnZungqr);
#undef JAX_GPU_DEFINE_ORGQR

// Cholesky decomposition: potrf

// Per-Type explicit template specializations (declared in solver_interface.h)
// wrapping the vendor unbatched potrf entry points: PotrfBufferSize queries
// the required workspace size and Potrf factors one n x n matrix in place
// (leading dimension passed as n).
#define JAX_GPU_DEFINE_POTRF(Type, Name)                                  \
  template <>                                                             \
  absl::StatusOr<int> PotrfBufferSize<Type>(gpusolverDnHandle_t handle,   \
                                            gpusolverFillMode_t uplo,     \
                                            int n) {                      \
    int lwork;                                                            \
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                   \
        Name##_bufferSize(handle, uplo, n, /*A=*/nullptr, n, &lwork)));  \
    return lwork;                                                         \
  }                                                                       \
                                                                          \
  template <>                                                             \
  absl::Status Potrf<Type>(gpusolverDnHandle_t handle,                    \
                           gpusolverFillMode_t uplo, int n, Type *a,      \
                           Type *workspace, int lwork, int *info) {       \
    return JAX_AS_STATUS(                                                 \
        Name(handle, uplo, n, a, n, workspace, lwork, info));             \
  }

JAX_GPU_DEFINE_POTRF(float, gpusolverDnSpotrf);
JAX_GPU_DEFINE_POTRF(double, gpusolverDnDpotrf);
JAX_GPU_DEFINE_POTRF(gpuComplex, gpusolverDnCpotrf);
JAX_GPU_DEFINE_POTRF(gpuDoubleComplex, gpusolverDnZpotrf);
#undef JAX_GPU_DEFINE_POTRF

// Per-Type specializations wrapping the vendor batched potrf entry point:
// factors `batch` n x n matrices addressed through the device pointer array
// `a` (leading dimension lda), writing one status per matrix into `info`.
#define JAX_GPU_DEFINE_POTRF_BATCHED(Type, Name)                       \
  template <>                                                          \
  absl::Status PotrfBatched<Type>(gpusolverDnHandle_t handle,          \
                                  gpusolverFillMode_t uplo, int n,     \
                                  Type **a, int lda, int *info,        \
                                  int batch) {                         \
    return JAX_AS_STATUS(Name(handle, uplo, n, a, lda, info, batch));  \
  }

JAX_GPU_DEFINE_POTRF_BATCHED(float, gpusolverDnSpotrfBatched);
JAX_GPU_DEFINE_POTRF_BATCHED(double, gpusolverDnDpotrfBatched);
JAX_GPU_DEFINE_POTRF_BATCHED(gpuComplex, gpusolverDnCpotrfBatched);
JAX_GPU_DEFINE_POTRF_BATCHED(gpuDoubleComplex, gpusolverDnZpotrfBatched);
#undef JAX_GPU_DEFINE_POTRF_BATCHED

// Symmetric (Hermitian) eigendecomposition:
// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
// * QR algorithm: syevd/heevd
Expand Down
19 changes: 19 additions & 0 deletions jaxlib/gpu/solver_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, OrgqrBufferSize);
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Orgqr);
#undef JAX_GPU_SOLVER_Orgqr_ARGS

// Cholesky decomposition: potrf
//
// Each *_ARGS macro fixes the parameter list used by
// JAX_GPU_SOLVER_EXPAND_DEFINITION to declare the Type-templated wrapper,
// and is #undef'd immediately after use.

// Workspace-size query for the unbatched factorization.
#define JAX_GPU_SOLVER_PotrfBufferSize_ARGS(Type, ...) \
  gpusolverDnHandle_t handle, gpusolverFillMode_t uplo, int n
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, PotrfBufferSize);
#undef JAX_GPU_SOLVER_PotrfBufferSize_ARGS

// In-place factorization of a single n x n matrix.
#define JAX_GPU_SOLVER_Potrf_ARGS(Type, ...) \
  gpusolverDnHandle_t handle, gpusolverFillMode_t uplo, int n, Type *a, \
  Type *workspace, int lwork, int *info
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Potrf);
#undef JAX_GPU_SOLVER_Potrf_ARGS

// Batched factorization through an array of per-matrix device pointers.
#define JAX_GPU_SOLVER_PotrfBatched_ARGS(Type, ...) \
  gpusolverDnHandle_t handle, gpusolverFillMode_t uplo, int n, Type **a, \
  int lda, int *info, int batch
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, PotrfBatched);
#undef JAX_GPU_SOLVER_PotrfBatched_ARGS

// Symmetric (Hermitian) eigendecomposition:
// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
// * QR algorithm: syevd/heevd
Expand Down
107 changes: 107 additions & 0 deletions jaxlib/gpu/solver_kernels_ffi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,113 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
.Ret<ffi::AnyBuffer>() // out
);

// Cholesky decomposition: potrf

// Unbatched implementation: copies `a` into `out` (when not aliased by XLA)
// and factors each n x n matrix in place with the vendor potrf, one solver
// call per matrix, writing one int32 status per matrix into `info`.
template <typename T>
ffi::Error PotrfImpl(int64_t batch, int64_t size, gpuStream_t stream,
                     ffi::ScratchAllocator& scratch, bool lower,
                     ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                     ffi::Result<ffi::Buffer<ffi::S32>> info) {
  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(size));
  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));

  gpusolverFillMode_t uplo =
      lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;

  FFI_ASSIGN_OR_RETURN(int lwork,
                       solver::PotrfBufferSize<T>(handle.get(), uplo, n));
  FFI_ASSIGN_OR_RETURN(auto workspace,
                       AllocateWorkspace<T>(scratch, lwork, "potrf"));

  auto a_data = static_cast<T*>(a.untyped_data());
  auto out_data = static_cast<T*>(out->untyped_data());
  auto info_data = info->typed_data();
  if (a_data != out_data) {
    // The solver factors in place; materialize the input into the output
    // buffer when XLA did not alias them.
    JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
  }

  // Use a 64-bit stride: `int` n * n would overflow for n > 46340, which is
  // a representable (if large) square matrix. Likewise index with int64_t to
  // match the type of `batch`.
  int64_t out_step = static_cast<int64_t>(n) * n;
  for (int64_t i = 0; i < batch; ++i) {
    FFI_RETURN_IF_ERROR_STATUS(solver::Potrf<T>(handle.get(), uplo, n,
                                                out_data, workspace, lwork,
                                                info_data));
    out_data += out_step;
    ++info_data;
  }
  return ffi::Error::Success();
}

// Batched implementation: hands all `batch` matrices to the vendor
// potrfBatched kernel in a single call instead of looping on the host.
template <typename T>
ffi::Error PotrfBatchedImpl(int64_t batch, int64_t size, gpuStream_t stream,
                            ffi::ScratchAllocator& scratch, bool lower,
                            ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                            ffi::Result<ffi::Buffer<ffi::S32>> info) {
  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(size));
  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
  // Scratch array holding one device pointer per matrix in the batch.
  FFI_ASSIGN_OR_RETURN(auto batch_ptrs,
                       AllocateWorkspace<T*>(scratch, batch, "batched potrf"));

  gpusolverFillMode_t uplo =
      lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;

  auto a_data = a.untyped_data();
  auto out_data = out->untyped_data();
  auto info_data = info->typed_data();
  // The solver factors in place; materialize the input into the output
  // buffer when XLA did not alias them.
  if (a_data != out_data) {
    JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
  }

  // Presumably fills batch_ptrs with per-matrix pointers into out_data at a
  // stride of n * n elements (sizeof(T) * n * n bytes) on `stream`; check for
  // a launch error before handing the pointers to the solver.
  MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch,
                         sizeof(T) * n * n);
  JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());

  // NOTE(review): `batch` narrows from int64_t to the solver's int parameter
  // here — fine for realistic batch sizes, but worth confirming upstream
  // shape checks bound it.
  FFI_RETURN_IF_ERROR_STATUS(solver::PotrfBatched<T>(
      handle.get(), uplo, n, batch_ptrs, n, info_data, batch));

  return ffi::Error::Success();
}

// Validates shapes/dtypes and dispatches on the element type of `a` to the
// batched or unbatched implementation. SOLVER_DISPATCH_IMPL presumably
// returns from this function when the dtype matches a supported type, so
// reaching the final return means the dtype is unsupported.
ffi::Error PotrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
                         bool lower, ffi::AnyBuffer a,
                         ffi::Result<ffi::AnyBuffer> out,
                         ffi::Result<ffi::Buffer<ffi::S32>> info) {
  auto dataType = a.element_type();
  if (dataType != out->element_type()) {
    return ffi::Error::InvalidArgument(
        "The input and output to potrf must have the same element type");
  }
  // Leading dimensions form the batch; the trailing two must be square.
  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
                       SplitBatch2D(a.dimensions()));
  if (rows != cols) {
    return ffi::Error::InvalidArgument(
        "The input matrix to potrf must be square");
  }
  FFI_RETURN_IF_ERROR(
      CheckShape(out->dimensions(), {batch, rows, cols}, "out", "potrf"));
  FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "potrf"));
  // Only use the vendor batched kernel when there is more than one matrix.
  if (batch > 1) {
    SOLVER_DISPATCH_IMPL(PotrfBatchedImpl, batch, rows, stream, scratch, lower,
                         a, out, info);
  } else {
    SOLVER_DISPATCH_IMPL(PotrfImpl, batch, rows, stream, scratch, lower, a,
                         out, info);
  }
  return ffi::Error::InvalidArgument(absl::StrFormat(
      "Unsupported dtype %s in potrf", absl::FormatStreamed(dataType)));
}

// Binds PotrfDispatch as the potrf FFI handler: platform stream and scratch
// allocator from the execution context, a "lower" attribute selecting the
// triangle, the input matrix, and results for the factor and per-matrix
// info statuses.
XLA_FFI_DEFINE_HANDLER_SYMBOL(PotrfFfi, PotrfDispatch,
                              ffi::Ffi::Bind()
                                  .Ctx<ffi::PlatformStream<gpuStream_t>>()
                                  .Ctx<ffi::ScratchAllocator>()
                                  .Attr<bool>("lower")
                                  .Arg<ffi::AnyBuffer>()         // a
                                  .Ret<ffi::AnyBuffer>()         // out
                                  .Ret<ffi::Buffer<ffi::S32>>()  // info
);

// Symmetric (Hermitian) eigendecomposition:
// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
// * QR algorithm: syevd/heevd
Expand Down
1 change: 1 addition & 0 deletions jaxlib/gpu/solver_kernels_ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ enum class SyevdAlgorithm : uint8_t {
// FFI handler symbols; definitions live in solver_kernels_ffi.cc.
XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(PotrfFfi);  // Cholesky decomposition
XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdFfi);
Expand Down
24 changes: 24 additions & 0 deletions jaxlib/gpu/vendor.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,18 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpusolverDnDorgqr_bufferSize cusolverDnDorgqr_bufferSize
#define gpusolverDnCungqr_bufferSize cusolverDnCungqr_bufferSize
#define gpusolverDnZungqr_bufferSize cusolverDnZungqr_bufferSize
// Cholesky factorization (potrf): vendor-neutral aliases for the cuSOLVER
// unbatched, workspace-query, and batched entry points.
#define gpusolverDnSpotrf cusolverDnSpotrf
#define gpusolverDnDpotrf cusolverDnDpotrf
#define gpusolverDnCpotrf cusolverDnCpotrf
#define gpusolverDnZpotrf cusolverDnZpotrf
#define gpusolverDnSpotrf_bufferSize cusolverDnSpotrf_bufferSize
#define gpusolverDnDpotrf_bufferSize cusolverDnDpotrf_bufferSize
#define gpusolverDnCpotrf_bufferSize cusolverDnCpotrf_bufferSize
#define gpusolverDnZpotrf_bufferSize cusolverDnZpotrf_bufferSize
#define gpusolverDnSpotrfBatched cusolverDnSpotrfBatched
#define gpusolverDnDpotrfBatched cusolverDnDpotrfBatched
#define gpusolverDnCpotrfBatched cusolverDnCpotrfBatched
#define gpusolverDnZpotrfBatched cusolverDnZpotrfBatched
#define gpusolverDnSsyevd cusolverDnSsyevd
#define gpusolverDnDsyevd cusolverDnDsyevd
#define gpusolverDnCheevd cusolverDnCheevd
Expand Down Expand Up @@ -591,6 +603,18 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpusolverDnDorgqr_bufferSize hipsolverDorgqr_bufferSize
#define gpusolverDnCungqr_bufferSize hipsolverCungqr_bufferSize
#define gpusolverDnZungqr_bufferSize hipsolverZungqr_bufferSize
// Cholesky factorization (potrf): hipSOLVER equivalents of the CUDA aliases.
// NOTE(review): the unbatched and bufferSize aliases target the plain
// hipsolver* API while the batched ones target hipsolverDn* names — confirm
// both spellings exist in the hipSOLVER headers being built against.
#define gpusolverDnSpotrf hipsolverSpotrf
#define gpusolverDnDpotrf hipsolverDpotrf
#define gpusolverDnCpotrf hipsolverCpotrf
#define gpusolverDnZpotrf hipsolverZpotrf
#define gpusolverDnSpotrf_bufferSize hipsolverSpotrf_bufferSize
#define gpusolverDnDpotrf_bufferSize hipsolverDpotrf_bufferSize
#define gpusolverDnCpotrf_bufferSize hipsolverCpotrf_bufferSize
#define gpusolverDnZpotrf_bufferSize hipsolverZpotrf_bufferSize
#define gpusolverDnSpotrfBatched hipsolverDnSpotrfBatched
#define gpusolverDnDpotrfBatched hipsolverDnDpotrfBatched
#define gpusolverDnCpotrfBatched hipsolverDnCpotrfBatched
#define gpusolverDnZpotrfBatched hipsolverDnZpotrfBatched
#define gpusolverDnSsyevd hipsolverSsyevd
#define gpusolverDnDsyevd hipsolverDsyevd
#define gpusolverDnCheevd hipsolverCheevd
Expand Down
1 change: 1 addition & 0 deletions tests/export_back_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def test_custom_call_coverage(self):
"hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
"hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi", "hipsolver_syevd_ffi",
"hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi",
"cusolver_potrf_ffi", "hipsolver_potrf_ffi",
})
not_covered = targets_to_cover.difference(covered_targets)
self.assertEmpty(not_covered,
Expand Down
Loading