From 97b001a9a927006bd4d0c37f7288f0aa16a25306 Mon Sep 17 00:00:00 2001
From: Peter Hawkins 
Date: Fri, 10 Oct 2025 07:01:50 -0700
Subject: [PATCH] Switch the default nonsymmetric eigendecomposition
 implementation to use cusolver.

This regresses a test that checks that eigendecompositions behave well if
provided NaNs as input. I think this is the right tradeoff, however, since the
cusolver implementation should be faster and the old implementation is still
available.

Deprecate the use_magma argument to `lax.linalg.eig` and replace it with an
`implementation` argument.

This PR builds on a partial PR from @dfm.

Fixes https://github.com/jax-ml/jax/issues/27265

PiperOrigin-RevId: 817628185
---
 CHANGELOG.md                     |   5 +
 jax/_src/lax/linalg.py           | 201 ++++++++++++++++++++++++-------
 jax/lax/linalg.py                |   1 +
 jaxlib/gpu/solver.cc             |   4 +
 jaxlib/gpu/solver_kernels_ffi.cc | 151 +++++++++++++++++++++--
 jaxlib/gpu/solver_kernels_ffi.h  |   4 +
 jaxlib/gpu/vendor.h              |  10 ++
 tests/linalg_test.py             |  40 ++++--
 8 files changed, 353 insertions(+), 63 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4c53808951c3..d8f2eb492197 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -68,6 +68,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     mode is not enabled.
   * {func}`jax.dlpack.from_dlpack` now accepts arrays with non-default
     layouts, for example, transposed.
+  * The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses
+    cusolver. The magma and LAPACK implementations are still available via the
+    new `implementation` argument to {func}`jax.lax.linalg.eig`
+    ({jax-issue}`#27265`). The `use_magma` argument is now deprecated in favor
+    of `implementation`.
 
 * Deprecations
   * {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64`
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index a7492c828379..5b3f03fd3a16 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -20,6 +20,7 @@
 import math
 import string
 from typing import Any, Literal, overload
+import warnings
 
 import numpy as np
 
@@ -40,9 +41,11 @@
 from jax._src.lax import lax
 from jax._src.lax import utils as lax_utils
 from jax._src.lax.lax import _float, _complex, _int
+from jax._src.lib import cuda_versions
 from jax._src.lib import gpu_linalg
 from jax._src.lib import gpu_solver
 from jax._src.lib import gpu_sparse
+from jax._src.lib import version as jaxlib_version
 from jax._src.lib import lapack
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
@@ -119,12 +122,19 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array:
   r_matrix, w_vector = core.standard_insert_pvary(r_matrix, w_vector)
   return cholesky_update_p.bind(r_matrix, w_vector)
 
+class EigImplementation(enum.Enum):
+  """Enum selecting the nonsymmetric eigendecomposition implementation."""
+  CUSOLVER = "cusolver"
+  MAGMA = "magma"
+  LAPACK = "lapack"
+
 
 def eig(
     x: ArrayLike,
     *,
     compute_left_eigenvectors: bool = True,
     compute_right_eigenvectors: bool = True,
+    implementation: EigImplementation | None = None,
     use_magma: bool | None = None,
 ) -> list[Array]:
   """Eigendecomposition of a general matrix.
@@ -159,11 +169,22 @@
     compute_left_eigenvectors: If true, the left eigenvectors will be computed.
     compute_right_eigenvectors: If true, the right eigenvectors will be
       computed.
-    use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the
-      eigendecomposition is computed using MAGMA. If ``False``, the computation
-      is done using LAPACK on to the host CPU. If ``None`` (default), the
-      behavior is controlled by the ``jax_use_magma`` flag. This argument
-      is only used on GPU.
+    use_magma: Deprecated, please use ``implementation`` instead. Locally
+      override the ``jax_use_magma`` flag. If ``True``, the eigendecomposition
+      is computed using MAGMA. If ``False``, the computation is done using
+      LAPACK on the host CPU. If ``None`` (default), the behavior is
+      controlled by the ``jax_use_magma`` flag. This argument is only used on
+      GPU. Will be removed in JAX 0.9.
+    implementation: Controls the choice of eigendecomposition algorithm. If
+      ``LAPACK``, the computation will be performed using LAPACK on the host
+      CPU. If ``MAGMA``, the computation will be performed using the MAGMA
+      library on the GPU. If ``CUSOLVER``, the computation will be performed
+      using the cusolver library on the GPU. The ``CUSOLVER`` implementation
+      requires cusolver 11.7.1 (from CUDA 12.6 update 2) to be installed, and
+      does not support computing left eigenvectors.
+      If ``None`` (default), an automatic choice will be made, depending on the
+      cusolver version, whether left eigenvectors were requested, and the
+      ``jax_use_magma`` configuration variable.
 
   Returns:
     The eigendecomposition of ``x``, which is a tuple of the form
@@ -175,9 +196,19 @@
     If the eigendecomposition fails, then arrays full of NaNs will be returned
     for that batch element.
   """
+  if use_magma is not None:
+    warnings.warn(
+        "use_magma is deprecated, please use"
+        " implementation=EigImplementation.MAGMA instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    implementation = (
+        EigImplementation.MAGMA if use_magma else EigImplementation.LAPACK
+    )
   return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
                     compute_right_eigenvectors=compute_right_eigenvectors,
-                    use_magma=use_magma)
+                    implementation=implementation)
 
 
 def eigh(
@@ -926,8 +957,9 @@ def _eig_compute_attr(compute):
   )
 
 def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
-                      compute_right_eigenvectors, use_magma):
-  del use_magma  # unused
+                      compute_right_eigenvectors, implementation):
+  if implementation and implementation != EigImplementation.LAPACK:
+    raise ValueError("Only the lapack implementation is supported on CPU.")
   operand_aval, = ctx.avals_in
   out_aval = ctx.avals_out[0]
   batch_dims = operand_aval.shape[:-2]
@@ -961,48 +993,136 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
     output.append(vr)
   return output
 
+def _unpack_conjugate_pairs(w, vr):
+  # cusolver, like LAPACK, uses a packed representation of the complex
+  # eigenvectors, where the (re, im) vectors are adjacent and shared by the
+  # conjugate pair:
+  # https://docs.nvidia.com/cuda/cusolver/index.html?highlight=geev#cusolverdnxgeev
+  if w.size == 0:
+    return lax.complex(vr, lax.zeros_like_array(vr))
+
+  is_real = (w.imag == 0) | (w.imag != w.imag)  # NaN imag counts as real.
+  # Find the start of each conjugate pair from the parity of the number of
+  # complex eigenvalues seen so far.
+ conj_pair_start = control_flow.cumsum((~is_real).astype(int), + axis=len(w.shape) - 1) + conj_pair_start = conj_pair_start % 2 == 1 + pads = [(0, 0, 0)] * (len(vr.shape)) + pads[-1] = (-1, 1, 0) + vr_shifted_left = lax.pad(vr, lax._zero(vr), pads) + pads[-1] = (1, -1, 0) + vr_shifted_right = lax.pad(vr, lax._zero(vr), pads) + dims = np.delete(np.arange(len(vr.shape), dtype=np.int32), -2) + is_real = lax.broadcast_in_dim(is_real, vr.shape, broadcast_dimensions=dims) + conj_pair_start = lax.broadcast_in_dim(conj_pair_start, vr.shape, + broadcast_dimensions=dims) + re = lax.select(is_real | conj_pair_start, vr, vr_shifted_right) + im = lax.select(conj_pair_start, vr_shifted_left, -vr) + im = lax.select(is_real, lax.zeros_like_array(vr), im) + return lax.complex(re, im) + + def _eig_gpu_lowering(ctx, operand, *, compute_left_eigenvectors, compute_right_eigenvectors, - use_magma, target_name_prefix): + implementation, target_name_prefix): operand_aval, = ctx.avals_in batch_dims = operand_aval.shape[:-2] n, m = operand_aval.shape[-2:] assert n == m - gpu_solver.initialize_hybrid_kernels() dtype = operand_aval.dtype - is_real = dtype == np.float32 or dtype == np.float64 - if is_real: - target_name = f"{target_name_prefix}hybrid_eig_real" - complex_dtype = np.complex64 if dtype == np.float32 else np.complex128 + complex_dtype = np.result_type(dtype, 1j) + if dtype in (np.float32, np.float64): + is_real = True + elif dtype in (np.complex64, np.complex128): + is_real = False else: - target_name = f"{target_name_prefix}hybrid_eig_comp" - assert dtype == np.complex64 or dtype == np.complex128 - complex_dtype = dtype - - avals_out = [ - ShapedArray(batch_dims + (n,), dtype), - ShapedArray(batch_dims + (n, n), complex_dtype), - ShapedArray(batch_dims + (n, n), complex_dtype), - ShapedArray(batch_dims, np.int32), - ] - if is_real: - avals_out = [ShapedArray(batch_dims + (n,), dtype)] + avals_out + raise ValueError(f"Unsupported dtype: {dtype}") - magma = config.gpu_use_magma.value - if use_magma is not None: - magma = "on" if use_magma else "off" + have_cusolver_geev = ( + target_name_prefix == "cu" + and jaxlib_version >= (0, 8) + and cuda_versions + and cuda_versions.cusolver_get_version() >= 11701 + ) - rule = _linalg_ffi_lowering(target_name, avals_out=avals_out) - *w, vl, vr, info = rule(ctx, operand, magma=magma, - left=compute_left_eigenvectors, - right=compute_right_eigenvectors) - if is_real: - assert len(w) == 2 - w = hlo.complex(*w) + if ( + implementation is None and have_cusolver_geev + and not compute_left_eigenvectors + ) or implementation == EigImplementation.CUSOLVER: + if not have_cusolver_geev: + raise RuntimeError( + "Nonsymmetric eigendecomposition requires jaxlib 0.8 and cusolver" + " 11.7.1 or newer" + ) + if compute_left_eigenvectors: + raise NotImplementedError( + "Left eigenvectors are not supported by cusolver") + target_name = f"{target_name_prefix}solver_geev_ffi" + avals_out = [ + ShapedArray(batch_dims + (n, n), dtype), + ShapedArray(batch_dims + (n,), complex_dtype), + ShapedArray(batch_dims + (n, n), dtype), + ShapedArray(batch_dims + (n, n), dtype), + ShapedArray(batch_dims, np.int32), + ] + + rule = _linalg_ffi_lowering(target_name, avals_out=avals_out) + _, w, vl, vr, info = rule(ctx, operand, left=compute_left_eigenvectors, + right=compute_right_eigenvectors) + if is_real: + unpack = mlir.lower_fun(_unpack_conjugate_pairs, multiple_results=False) + if compute_left_eigenvectors: + sub_ctx = ctx.replace( + primitive=None, + avals_in=[ + ShapedArray(batch_dims + (n,), 
complex_dtype), + ShapedArray(batch_dims + (n, n), dtype), + ], + avals_out=[ShapedArray(batch_dims + (n, n), complex_dtype)], + ) + vl, = unpack(sub_ctx, w, vl) + if compute_right_eigenvectors: + sub_ctx = ctx.replace( + primitive=None, + avals_in=[ + ShapedArray(batch_dims + (n,), complex_dtype), + ShapedArray(batch_dims + (n, n), dtype), + ], + avals_out=[ShapedArray(batch_dims + (n, n), complex_dtype)], + ) + vr, = unpack(sub_ctx, w, vr) else: - assert len(w) == 1 - w = w[0] + magma = config.gpu_use_magma.value + if implementation is not None: + magma = "on" if implementation == EigImplementation.MAGMA else "off" + gpu_solver.initialize_hybrid_kernels() + if is_real: + target_name = f"{target_name_prefix}hybrid_eig_real" + complex_dtype = np.complex64 if dtype == np.float32 else np.complex128 + else: + target_name = f"{target_name_prefix}hybrid_eig_comp" + assert dtype == np.complex64 or dtype == np.complex128 + complex_dtype = dtype + + avals_out = [ + ShapedArray(batch_dims + (n,), dtype), + ShapedArray(batch_dims + (n, n), complex_dtype), + ShapedArray(batch_dims + (n, n), complex_dtype), + ShapedArray(batch_dims, np.int32), + ] + if is_real: + avals_out = [ShapedArray(batch_dims + (n,), dtype)] + avals_out + rule = _linalg_ffi_lowering(target_name, avals_out=avals_out) + *w, vl, vr, info = rule(ctx, operand, magma=magma, + left=compute_left_eigenvectors, + right=compute_right_eigenvectors) + if is_real: + assert len(w) == 2 + w = hlo.complex(*w) + else: + assert len(w) == 1 + w = w[0] zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.int32)) ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED") w_aval = ShapedArray(batch_dims + (n,), complex_dtype) @@ -1019,8 +1139,7 @@ def _eig_gpu_lowering(ctx, operand, *, return output def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, - compute_right_eigenvectors, use_magma): - del use_magma # unused + compute_right_eigenvectors, implementation): if compute_left_eigenvectors or compute_right_eigenvectors: raise NotImplementedError( 'The derivatives of eigenvectors are not implemented, only ' @@ -1030,7 +1149,7 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, # https://arxiv.org/abs/1701.00392 a, = primals da, = tangents - l, v = eig(a, compute_left_eigenvectors=False) + l, v = eig(a, compute_left_eigenvectors=False, implementation=implementation) return [l], [(_solve(v, da.astype(v.dtype)) * _T(v)).sum(-1)] eig_p = linalg_primitive( diff --git a/jax/lax/linalg.py b/jax/lax/linalg.py index 984592534656..c3ade1ae0785 100644 --- a/jax/lax/linalg.py +++ b/jax/lax/linalg.py @@ -17,6 +17,7 @@ cholesky_p as cholesky_p, cholesky_update as cholesky_update, cholesky_update_p as cholesky_update_p, + EigImplementation as EigImplementation, eig as eig, eig_p as eig_p, eigh as eigh, diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 08d25948d893..53ee0aa1f0e1 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -41,6 +41,10 @@ nb::dict Registrations() { EncapsulateFfiHandler(CsrlsvqrFfi); #endif // JAX_GPU_CUDA +#if JAX_GPU_HAVE_SOLVER_GEEV + dict[JAX_GPU_PREFIX "solver_geev_ffi"] = EncapsulateFfiHandler(GeevFfi); +#endif // JAX_GPU_HAVE_SOLVER_GEEV + return dict; } diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 0712bcafa739..49d6aff79f76 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -18,7 +18,6 @@ limitations under the License. 
 #include 
 #include 
 #include 
-#include 
 #include 
 
 #if JAX_GPU_HAVE_64_BIT
@@ -51,8 +50,6 @@ namespace JAX_GPU_NAMESPACE {
 
 namespace ffi = ::xla::ffi;
 
-#if JAX_GPU_HAVE_64_BIT
-
 // Map an FFI buffer element type to the appropriate GPU solver type.
 inline absl::StatusOr SolverDataType(ffi::DataType dataType,
                                      std::string_view func) {
@@ -71,8 +68,6 @@ inline absl::StatusOr SolverDataType(ffi::DataType dataType,
   }
 }
 
-#endif
-
 #define SOLVER_DISPATCH_IMPL(impl, ...)   \
   switch (dataType) {                     \
     case ffi::F32:                        \
@@ -434,7 +429,8 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
       params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); });
 
   int64_t batch_step = 1;
-  FFI_ASSIGN_OR_RETURN(bool is_batched_syev_supported, IsSyevBatchedSupported());
+  FFI_ASSIGN_OR_RETURN(bool is_batched_syev_supported,
+                       IsSyevBatchedSupported());
   if (is_batched_syev_supported) {
     int64_t matrix_size = n * n * ffi::ByteWidth(dataType);
     batch_step = std::numeric_limits::max() / matrix_size;
@@ -450,7 +446,8 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
         &workspaceInBytesOnHost, std::min(batch, batch_step)));
   } else {
     if (batch_step != 1) {
-      return ffi::Error(ffi::ErrorCode::kInternal,
+      return ffi::Error(
+          ffi::ErrorCode::kInternal,
           "Syevd64Impl: batch_step != 1 but batched syev is not supported");
     }
     JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize(
@@ -484,17 +481,19 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
     size_t batch_size = static_cast(std::min(batch_step, batch - i));
     if (is_batched_syev_supported) {
       JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevBatched(
-          handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data,
-          aType, workspaceOnDevice, workspaceInBytesOnDevice,
-          workspaceOnHost.get(), workspaceInBytesOnHost, info_data, batch_size));
+          handle.get(), params, jobz, uplo, n, aType, out_data, n, wType,
+          w_data, aType, workspaceOnDevice, workspaceInBytesOnDevice,
+          workspaceOnHost.get(), workspaceInBytesOnHost, info_data,
+          batch_size));
     } else {
       if (batch_step != 1) {
-        return ffi::Error(ffi::ErrorCode::kInternal,
+        return ffi::Error(
+            ffi::ErrorCode::kInternal,
             "Syevd64Impl: batch_step != 1 but batched syev is not supported");
       }
       JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd(
-          handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data,
-          aType, workspaceOnDevice, workspaceInBytesOnDevice,
+          handle.get(), params, jobz, uplo, n, aType, out_data, n, wType,
+          w_data, aType, workspaceOnDevice, workspaceInBytesOnDevice,
           workspaceOnHost.get(), workspaceInBytesOnHost, info_data));
     }
     out_data += out_step;
@@ -1215,6 +1214,132 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SytrdFfi, SytrdDispatch,
                                   .Ret>()  // info
 );
 
+// General eigenvalue decomposition: geev
+
+#if JAX_GPU_HAVE_SOLVER_GEEV
+
+ffi::Error GeevImpl(gpuStream_t stream, ffi::ScratchAllocator scratch,
+                    bool left, bool right, ffi::AnyBuffer a,
+                    ffi::Result<ffi::AnyBuffer> out,
+                    ffi::Result<ffi::AnyBuffer> w,
+                    ffi::Result<ffi::AnyBuffer> vl,
+                    ffi::Result<ffi::AnyBuffer> vr,
+                    ffi::Result<ffi::Buffer<ffi::S32>> info) {
+  auto dataType = a.element_type();
+  if (dataType != vr->element_type()) {
+    return ffi::Error::InvalidArgument(
+        "The inputs and outputs to geev must have the same element type");
+  }
+  if (dataType != w->element_type() &&
+      !(dataType == ffi::F32 && w->element_type() == ffi::C64) &&
+      !(dataType == ffi::F64 && w->element_type() == ffi::C128)) {
+    return ffi::Error::InvalidArgument(
+        "The eigenvalue output type of geev must match the input type or "
+        "be its complex counterpart.");
+  }
+
+  FFI_ASSIGN_OR_RETURN((auto [batch, m, n]), SplitBatch2D(a.dimensions()));
+  if (m != n) {
+    return ffi::Error::InvalidArgument(
+        "The input matrix to geev must be square");
+  }
+  int w_len;
+  if (w->element_type() == ffi::F32 || w->element_type() == ffi::F64) {
+    w_len = 2 * n;
+    FFI_RETURN_IF_ERROR(
+        CheckShape(w->dimensions(), {batch, 2 * n}, "w", "geev"));
+  } else {
+    FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, n}, "w", "geev"));
+    w_len = n;
+  }
+  if (left) {
+    FFI_RETURN_IF_ERROR(
+        CheckShape(vl->dimensions(), {batch, n, n}, "vl", "geev"));
+  }
+  if (right) {
+    FFI_RETURN_IF_ERROR(
+        CheckShape(vr->dimensions(), {batch, n, n}, "vr", "geev"));
+  }
+  FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "geev"));
+
+  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
+  FFI_ASSIGN_OR_RETURN(auto aType, SolverDataType(dataType, "geev"));
+  FFI_ASSIGN_OR_RETURN(auto wType, SolverDataType(w->element_type(), "geev"));
+
+  // At the time of writing, cusolver only supports computing right
+  // eigenvectors, but its API already accepts a flag for left eigenvectors.
+  // We pass the flag through in case support is added in the future.
+  gpusolverEigMode_t jobvl =
+      left ? GPUSOLVER_EIG_MODE_VECTOR : GPUSOLVER_EIG_MODE_NOVECTOR;
+  gpusolverEigMode_t jobvr =
+      right ? GPUSOLVER_EIG_MODE_VECTOR : GPUSOLVER_EIG_MODE_NOVECTOR;
+
+  gpusolverDnParams_t params;
+  JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(&params));
+  std::unique_ptr
+      params_cleanup(
+          params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); });
+
+  size_t workspaceInBytesOnDevice, workspaceInBytesOnHost;
+  JAX_FFI_RETURN_IF_GPU_ERROR(cusolverDnXgeev_bufferSize(
+      handle.get(), params, jobvl, jobvr, n, aType, /*a=*/nullptr, n, wType,
+      /*w=*/nullptr, aType, /*vl=*/nullptr, n, aType, /*vr=*/nullptr, n, aType,
+      &workspaceInBytesOnDevice, &workspaceInBytesOnHost));
+
+  auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice);
+  if (!maybe_workspace.has_value()) {
+    return ffi::Error(ffi::ErrorCode::kResourceExhausted,
+                      "Unable to allocate device workspace for geev");
+  }
+  auto workspaceOnDevice = maybe_workspace.value();
+  auto workspaceOnHost =
+      std::unique_ptr<char[]>(new char[workspaceInBytesOnHost]);
+
+  const char* a_data = static_cast<const char*>(a.untyped_data());
+  char* out_data = static_cast<char*>(out->untyped_data());
+  char* w_data = static_cast<char*>(w->untyped_data());
+  char* vl_data = static_cast<char*>(vl->untyped_data());
+  char* vr_data = static_cast<char*>(vr->untyped_data());
+  int* info_data = info->typed_data();
+  if (a_data != out_data) {
+    JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
+  }
+
+  size_t out_step = n * n * ffi::ByteWidth(dataType);
+  size_t w_step = w_len * ffi::ByteWidth(w->element_type());
+
+  for (auto i = 0; i < batch; ++i) {
+    JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgeev(
+        handle.get(), params, jobvl, jobvr, n, aType, out_data, n, wType,
+        w_data, aType, vl_data, n, aType, vr_data, n, aType, workspaceOnDevice,
+        workspaceInBytesOnDevice, workspaceOnHost.get(), workspaceInBytesOnHost,
+        info_data));
+    out_data += out_step;
+    w_data += w_step;
+    vr_data += out_step;
+    ++info_data;
+  }
+
+  return ffi::Error::Success();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(GeevFfi, GeevImpl,
+                              ffi::Ffi::Bind()
+                                  .Ctx<ffi::PlatformStream<gpuStream_t>>()
+                                  .Ctx<ffi::ScratchAllocator>()
+                                  .Attr<bool>("left")
+                                  .Attr<bool>("right")
+                                  .Arg<ffi::AnyBuffer>()         // a
+                                  .Ret<ffi::AnyBuffer>()         // out
+                                  .Ret<ffi::AnyBuffer>()         // w
+                                  .Ret<ffi::AnyBuffer>()         // vl
+                                  .Ret<ffi::AnyBuffer>()         // vr
+                                  .Ret<ffi::Buffer<ffi::S32>>()  // info
+);
+
+#endif  // JAX_GPU_HAVE_SOLVER_GEEV
+
 #undef SOLVER_DISPATCH_IMPL
 #undef 
SOLVER_BLAS_DISPATCH_IMPL diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 8e90a310e170..e443509ce89f 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -43,6 +43,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrlsvqrFfi); #endif // JAX_GPU_CUDA +#if JAX_GPU_HAVE_SOLVER_GEEV +XLA_FFI_DECLARE_HANDLER_SYMBOL(GeevFfi); +#endif // JAX_GPU_HAVE_SOLVER_GEEV + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index b8eebcae3573..5b9d71360f5e 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -417,6 +417,14 @@ typedef cusolverDnParams_t gpusolverDnParams_t; #define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize #define gpusolverDnXgesvd cusolverDnXgesvd +#if CUDA_VERSION >= 12060 +#define JAX_GPU_HAVE_SOLVER_GEEV 1 +#define gpusolverDnXgeev_bufferSize cusolverDnXgeev_bufferSize +#define gpusolverDnXgeev cusolverDnXgeev +#else +#define JAX_GPU_HAVE_SOLVER_GEEV 0 +#endif // CUDA_VERSION >= 12060 + namespace jax::JAX_GPU_NAMESPACE { namespace { constexpr uint32_t kNumThreadsPerWarp = 32; @@ -756,6 +764,8 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuGetDeviceProperties hipGetDeviceProperties #define gpuLaunchCooperativeKernel hipLaunchCooperativeKernel +#define JAX_GPU_HAVE_SOLVER_GEEV 0 + namespace jax::JAX_GPU_NAMESPACE { namespace { constexpr uint32_t kNumThreadsPerWarp = 64; diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 14ff9796b86a..ffc74ad4591f 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -30,6 +30,8 @@ from jax import scipy as jsp from jax._src import config from jax._src.lax import linalg as lax_linalg +from jax._src.lib import cuda_versions +from jax._src.lib import version as jaxlib_version from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.numpy.util import promote_dtypes_inexact @@ -290,15 +292,29 @@ def check_left_eigenvectors(a, w, vl): check_right_eigenvectors(aH, wC, vl) a, = args_maker() - results = lax.linalg.eig( - a, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors) - w = results[0] - if compute_left_eigenvectors: - check_left_eigenvectors(a, w, results[1]) - if compute_right_eigenvectors: - check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors]) + implementations = [None] + + if ( + jtu.is_device_cuda() + and not compute_left_eigenvectors + and jaxlib_version >= (0, 8) + and cuda_versions + and cuda_versions.cusolver_get_version() >= 11701 + ): + implementations.append(jax.lax.linalg.EigImplementation.CUSOLVER) + + for implementation in implementations: + results = lax.linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + implementation=implementation) + w = results[0] + + if compute_left_eigenvectors: + check_left_eigenvectors(a, w, results[1]) + if compute_right_eigenvectors: + check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors]) self._CompileAndCheck(partial(jnp.linalg.eig), args_maker, rtol=1e-3) @@ -312,10 +328,16 @@ def check_left_eigenvectors(a, w, vl): def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jtu.is_device_cuda(): + # TODO(phawkins): CUSOLVER's implementation does not pass this test. 
+ implementation = jax.lax.linalg.EigImplementation.LAPACK + else: + implementation = None a = jnp.full(shape, jnp.nan, dtype) results = lax.linalg.eig( a, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors) + compute_right_eigenvectors=compute_right_eigenvectors, + implementation=implementation) for result in results: self.assertTrue(np.all(np.isnan(result)))
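
Example usage of the new `implementation` argument (a minimal sketch; it
assumes a jaxlib 0.8+ CUDA build with cusolver >= 11.7.1 for the CUSOLVER
path, while the LAPACK path works on any backend):

    import numpy as np
    import jax.numpy as jnp
    from jax.lax.linalg import eig, EigImplementation

    a = jnp.asarray(np.random.randn(4, 4), dtype=np.float32)

    # Default: on NVIDIA GPUs with a new enough cusolver, and when left
    # eigenvectors are not requested, this now lowers to cusolver's geev.
    w, vr = eig(a, compute_left_eigenvectors=False)

    # Pin an implementation explicitly; LAPACK always runs on the host CPU.
    w_cpu, vr_cpu = eig(a, compute_left_eigenvectors=False,
                        implementation=EigImplementation.LAPACK)

    # The old flag still works but now emits a DeprecationWarning.
    w_old, vr_old = eig(a, compute_left_eigenvectors=False, use_magma=False)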