5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -68,6 +68,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
mode is not enabled.
* {func}`jax.dlpack.from_dlpack` now accepts arrays with non-default layouts,
for example, transposed.
* The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses
cusolver. The magma and LAPACK implementations are still available via the
new `implementation` argument to {func}`jax.lax.linalg.eig`
({jax-issue}`#27265`). The `use_magma` argument is now deprecated in favor
of `implementation`.
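
For example (a sketch assuming a machine with a CUDA 12.6+ GPU, i.e. cusolver 11.7.1 or newer; the matrix is an arbitrary illustration), the implementation can now be pinned explicitly. Note that cusolver does not compute left eigenvectors:

```python
import jax.numpy as jnp
from jax.lax.linalg import EigImplementation, eig

a = jnp.array([[0.0, -1.0],
               [1.0, 0.0]])  # 90-degree rotation; eigenvalues are ±1j
w, vr = eig(a, compute_left_eigenvectors=False,
            compute_right_eigenvectors=True,
            implementation=EigImplementation.CUSOLVER)
```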

* Deprecations
* {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64`
201 changes: 160 additions & 41 deletions jax/_src/lax/linalg.py
@@ -20,6 +20,7 @@
import math
import string
from typing import Any, Literal, overload
import warnings

import numpy as np

@@ -40,9 +41,11 @@
from jax._src.lax import lax
from jax._src.lax import utils as lax_utils
from jax._src.lax.lax import _float, _complex, _int
from jax._src.lib import cuda_versions
from jax._src.lib import gpu_linalg
from jax._src.lib import gpu_solver
from jax._src.lib import gpu_sparse
from jax._src.lib import version as jaxlib_version
from jax._src.lib import lapack
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
@@ -119,12 +122,19 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array:
r_matrix, w_vector = core.standard_insert_pvary(r_matrix, w_vector)
return cholesky_update_p.bind(r_matrix, w_vector)

class EigImplementation(enum.Enum):
"""Enum for SVD algorithm."""
CUSOLVER = "cusolver"
MAGMA = "magma"
LAPACK = "lapack"


def eig(
x: ArrayLike,
*,
compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True,
implementation: EigImplementation | None = None,
use_magma: bool | None = None,
) -> list[Array]:
"""Eigendecomposition of a general matrix.
@@ -159,11 +169,22 @@ def eig(
compute_left_eigenvectors: If true, the left eigenvectors will be computed.
compute_right_eigenvectors: If true, the right eigenvectors will be
computed.
use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the
eigendecomposition is computed using MAGMA. If ``False``, the computation
is done using LAPACK on the host CPU. If ``None`` (default), the
behavior is controlled by the ``jax_use_magma`` flag. This argument
is only used on GPU.
use_magma: Deprecated, please use ``implementation`` instead. Locally
override the ``jax_use_magma`` flag. If ``True``, the eigendecomposition
is computed using MAGMA. If ``False``, the computation is done using
LAPACK on the host CPU. If ``None`` (default), the behavior is
controlled by the ``jax_use_magma`` flag. This argument is only used on
GPU. Will be removed in JAX 0.9.
implementation: Controls the choice of eigendecomposition algorithm. If
``LAPACK``, the computation will be performed using LAPACK on the host CPU.
If ``MAGMA``, the computation will be performed using the MAGMA library on
the GPU. If ``CUSOLVER``, the computation will be performed using the
Cusolver library on the GPU. The ``CUSOLVER`` implementation requires
Cusolver 11.7.1 (from CUDA 12.6 update 2) to be installed, and does not
support computing left eigenvectors.
If ``None`` (default), the choice is made automatically, based on the
Cusolver version, on whether left eigenvectors were requested, and on the
``jax_use_magma`` configuration variable.

Returns:
The eigendecomposition of ``x``, which is a tuple of the form
@@ -175,9 +196,19 @@
If the eigendecomposition fails, then arrays full of NaNs will be returned
for that batch element.
"""
if use_magma is not None:
warnings.warn(
"use_magma is deprecated, please use"
" implementation=EigImplementation.MAGMA instead.",
DeprecationWarning,
stacklevel=2,
)
implementation = (
EigImplementation.MAGMA if use_magma else EigImplementation.LAPACK
)
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors,
use_magma=use_magma)
implementation=implementation)
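
For illustration, a short sketch of the two call styles (runnable on CPU; the matrix is an arbitrary example):

```python
import warnings

import jax.numpy as jnp
from jax.lax.linalg import EigImplementation, eig

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

# New spelling: select the LAPACK implementation explicitly.
w, vl, vr = eig(a, implementation=EigImplementation.LAPACK)

# Deprecated spelling: warns, then is rewritten to
# implementation=EigImplementation.LAPACK by the shim above.
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  eig(a, use_magma=False)
assert any(issubclass(c.category, DeprecationWarning) for c in caught)
```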


def eigh(
@@ -926,8 +957,9 @@ def _eig_compute_attr(compute):
)

def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
compute_right_eigenvectors, use_magma):
del use_magma # unused
compute_right_eigenvectors, implementation):
if implementation and implementation != EigImplementation.LAPACK:
raise ValueError("Only the lapack implementation is supported on CPU.")
operand_aval, = ctx.avals_in
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
@@ -961,48 +993,136 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
output.append(vr)
return output

def _unpack_conjugate_pairs(w, vr):
# cusolver, like LAPACK, uses a packed representation of the complex
# eigenvectors, where the (re, im) vectors are adjacent and shared by the
# conjugate pair:
# https://docs.nvidia.com/cuda/cusolver/index.html?highlight=geev#cusolverdnxgeev
if w.size == 0:
return lax.complex(vr, lax.zeros_like_array(vr))

is_real = (w.imag == 0) | (w.imag != w.imag)  # second term is a NaN check
# Find the position at which each conjugate pair starts, via the parity of
# the running count of complex eigenvalues seen so far.
conj_pair_start = control_flow.cumsum((~is_real).astype(int),
axis=len(w.shape) - 1)
conj_pair_start = conj_pair_start % 2 == 1
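# Note: lax.pad with a negative low/high edge-padding below trims one column
# and appends a zero column, i.e. shifts vr by one slot along the last axis.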
pads = [(0, 0, 0)] * (len(vr.shape))
pads[-1] = (-1, 1, 0)
vr_shifted_left = lax.pad(vr, lax._zero(vr), pads)
pads[-1] = (1, -1, 0)
vr_shifted_right = lax.pad(vr, lax._zero(vr), pads)
dims = np.delete(np.arange(len(vr.shape), dtype=np.int32), -2)
is_real = lax.broadcast_in_dim(is_real, vr.shape, broadcast_dimensions=dims)
conj_pair_start = lax.broadcast_in_dim(conj_pair_start, vr.shape,
broadcast_dimensions=dims)
re = lax.select(is_real | conj_pair_start, vr, vr_shifted_right)
im = lax.select(conj_pair_start, vr_shifted_left, -vr)
im = lax.select(is_real, lax.zeros_like_array(vr), im)
return lax.complex(re, im)
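
As a standalone illustration of that packed layout (toy numbers in plain numpy, not the lowering itself): for a conjugate pair w[k] = a + bi, w[k+1] = a - bi, the real-valued vr stores Re(v) in column k and Im(v) in column k+1, and the two eigenvectors are Re(v) ± i·Im(v); a real eigenvalue's column is used as-is.

```python
import numpy as np

w = np.array([1 + 2j, 1 - 2j, 3 + 0j])  # conjugate pair, then a real eigenvalue
vr_packed = np.array([[0.5,  0.1, 1.0],   # columns 0/1: Re(v) and Im(v) of the
                      [0.3, -0.2, 0.0],   # pair's eigenvector; column 2: the
                      [0.0,  0.4, 0.0]])  # real eigenvalue's eigenvector
v0 = vr_packed[:, 0] + 1j * vr_packed[:, 1]  # eigenvector for 1 + 2j
v1 = vr_packed[:, 0] - 1j * vr_packed[:, 1]  # eigenvector for 1 - 2j
v2 = vr_packed[:, 2].astype(complex)         # eigenvector for 3
```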


def _eig_gpu_lowering(ctx, operand, *,
compute_left_eigenvectors, compute_right_eigenvectors,
use_magma, target_name_prefix):
implementation, target_name_prefix):
operand_aval, = ctx.avals_in
batch_dims = operand_aval.shape[:-2]
n, m = operand_aval.shape[-2:]
assert n == m

gpu_solver.initialize_hybrid_kernels()
dtype = operand_aval.dtype
is_real = dtype == np.float32 or dtype == np.float64
if is_real:
target_name = f"{target_name_prefix}hybrid_eig_real"
complex_dtype = np.complex64 if dtype == np.float32 else np.complex128
complex_dtype = np.result_type(dtype, 1j)
if dtype in (np.float32, np.float64):
is_real = True
elif dtype in (np.complex64, np.complex128):
is_real = False
else:
target_name = f"{target_name_prefix}hybrid_eig_comp"
assert dtype == np.complex64 or dtype == np.complex128
complex_dtype = dtype

avals_out = [
ShapedArray(batch_dims + (n,), dtype),
ShapedArray(batch_dims + (n, n), complex_dtype),
ShapedArray(batch_dims + (n, n), complex_dtype),
ShapedArray(batch_dims, np.int32),
]
if is_real:
avals_out = [ShapedArray(batch_dims + (n,), dtype)] + avals_out
raise ValueError(f"Unsupported dtype: {dtype}")

magma = config.gpu_use_magma.value
if use_magma is not None:
magma = "on" if use_magma else "off"
have_cusolver_geev = (
target_name_prefix == "cu"
and jaxlib_version >= (0, 8)
and cuda_versions
and cuda_versions.cusolver_get_version() >= 11701
)

rule = _linalg_ffi_lowering(target_name, avals_out=avals_out)
*w, vl, vr, info = rule(ctx, operand, magma=magma,
left=compute_left_eigenvectors,
right=compute_right_eigenvectors)
if is_real:
assert len(w) == 2
w = hlo.complex(*w)
if (
implementation is None and have_cusolver_geev
and not compute_left_eigenvectors
) or implementation == EigImplementation.CUSOLVER:
if not have_cusolver_geev:
raise RuntimeError(
"Nonsymmetric eigendecomposition requires jaxlib 0.8 and cusolver"
" 11.7.1 or newer"
)
if compute_left_eigenvectors:
raise NotImplementedError(
"Left eigenvectors are not supported by cusolver")
target_name = f"{target_name_prefix}solver_geev_ffi"
avals_out = [
ShapedArray(batch_dims + (n, n), dtype),
ShapedArray(batch_dims + (n,), complex_dtype),
ShapedArray(batch_dims + (n, n), dtype),
ShapedArray(batch_dims + (n, n), dtype),
ShapedArray(batch_dims, np.int32),
]

rule = _linalg_ffi_lowering(target_name, avals_out=avals_out)
_, w, vl, vr, info = rule(ctx, operand, left=compute_left_eigenvectors,
right=compute_right_eigenvectors)
if is_real:
unpack = mlir.lower_fun(_unpack_conjugate_pairs, multiple_results=False)
if compute_left_eigenvectors:
sub_ctx = ctx.replace(
primitive=None,
avals_in=[
ShapedArray(batch_dims + (n,), complex_dtype),
ShapedArray(batch_dims + (n, n), dtype),
],
avals_out=[ShapedArray(batch_dims + (n, n), complex_dtype)],
)
vl, = unpack(sub_ctx, w, vl)
if compute_right_eigenvectors:
sub_ctx = ctx.replace(
primitive=None,
avals_in=[
ShapedArray(batch_dims + (n,), complex_dtype),
ShapedArray(batch_dims + (n, n), dtype),
],
avals_out=[ShapedArray(batch_dims + (n, n), complex_dtype)],
)
vr, = unpack(sub_ctx, w, vr)
else:
assert len(w) == 1
w = w[0]
magma = config.gpu_use_magma.value
if implementation is not None:
magma = "on" if implementation == EigImplementation.MAGMA else "off"
gpu_solver.initialize_hybrid_kernels()
if is_real:
target_name = f"{target_name_prefix}hybrid_eig_real"
complex_dtype = np.complex64 if dtype == np.float32 else np.complex128
else:
target_name = f"{target_name_prefix}hybrid_eig_comp"
assert dtype == np.complex64 or dtype == np.complex128
complex_dtype = dtype

avals_out = [
ShapedArray(batch_dims + (n,), dtype),
ShapedArray(batch_dims + (n, n), complex_dtype),
ShapedArray(batch_dims + (n, n), complex_dtype),
ShapedArray(batch_dims, np.int32),
]
if is_real:
avals_out = [ShapedArray(batch_dims + (n,), dtype)] + avals_out
rule = _linalg_ffi_lowering(target_name, avals_out=avals_out)
*w, vl, vr, info = rule(ctx, operand, magma=magma,
left=compute_left_eigenvectors,
right=compute_right_eigenvectors)
if is_real:
assert len(w) == 2
w = hlo.complex(*w)
else:
assert len(w) == 1
w = w[0]
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.int32))
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
w_aval = ShapedArray(batch_dims + (n,), complex_dtype)
@@ -1019,8 +1139,7 @@ def _eig_gpu_lowering(ctx, operand, *,
return output

def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
compute_right_eigenvectors, use_magma):
del use_magma # unused
compute_right_eigenvectors, implementation):
if compute_left_eigenvectors or compute_right_eigenvectors:
raise NotImplementedError(
'The derivatives of eigenvectors are not implemented, only '
@@ -1030,7 +1149,7 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
# https://arxiv.org/abs/1701.00392
a, = primals
da, = tangents
l, v = eig(a, compute_left_eigenvectors=False)
l, v = eig(a, compute_left_eigenvectors=False, implementation=implementation)
return [l], [(_solve(v, da.astype(v.dtype)) * _T(v)).sum(-1)]
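
As a quick numpy sanity check of that return line (toy matrices, not part of the primitive): for A = V diag(λ) V⁻¹, first-order perturbation of simple eigenvalues gives dλ = diag(V⁻¹ dA V), and the row-wise sum trick above computes that diagonal without forming the full product.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
da = rng.normal(size=(4, 4))
l, v = np.linalg.eig(a)
# The sum trick used in eig_jvp_rule ...
dl_trick = (np.linalg.solve(v, da.astype(v.dtype)) * v.T).sum(-1)
# ... equals the diagonal of the similarity-transformed tangent.
dl_full = np.diag(np.linalg.solve(v, da.astype(v.dtype)) @ v)
assert np.allclose(dl_trick, dl_full)
```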

eig_p = linalg_primitive(
1 change: 1 addition & 0 deletions jax/lax/linalg.py
@@ -17,6 +17,7 @@
cholesky_p as cholesky_p,
cholesky_update as cholesky_update,
cholesky_update_p as cholesky_update_p,
EigImplementation as EigImplementation,
eig as eig,
eig_p as eig_p,
eigh as eigh,
4 changes: 4 additions & 0 deletions jaxlib/gpu/solver.cc
@@ -41,6 +41,10 @@ nb::dict Registrations() {
EncapsulateFfiHandler(CsrlsvqrFfi);
#endif // JAX_GPU_CUDA

#if JAX_GPU_HAVE_SOLVER_GEEV
dict[JAX_GPU_PREFIX "solver_geev_ffi"] = EncapsulateFfiHandler(GeevFfi);
#endif // JAX_GPU_HAVE_SOLVER_GEEV

return dict;
}
