-
JAX's JIT requires that the number of returned arrays and their shapes be known at compile time. There is no way to make the function you describe jit-compatible, because the number of arrays returned and the lengths of those arrays depend on the values of the input and so cannot be known at compile time. Often you can work around this constraint by rethinking the larger algorithm: rather than focusing on a step that computes a dynamic number of dynamically-shaped arrays, you might find a different way of accomplishing your ultimate goal. If you add more information about what you are trying to accomplish, we may be able to give some suggestions along those lines.
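As a minimal sketch of what such a reformulation could look like (not necessarily what fits your algorithm): instead of returning one index array per degenerate group, the grouping can be encoded in arrays whose shape matches the input. The function names `degeneracy_labels` and `unique_padded` are made up for this illustration, and the example assumes the input is a sorted 1-D array.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="decimals")
def degeneracy_labels(E, decimals=6):
    """Give every entry of a sorted 1-D array the index of its degenerate group."""
    rounded = jnp.round(E, decimals)
    # True wherever an entry differs from its predecessor, i.e. a new group starts.
    starts = jnp.concatenate([jnp.zeros(1, dtype=bool), rounded[1:] != rounded[:-1]])
    # The running count of group starts is the group index; output shape equals E.shape.
    return jnp.cumsum(starts)


@jax.jit
def unique_padded(E):
    """jnp.unique with a static output size, so the result shape is known at compile time."""
    n = E.shape[0]  # static under jit
    return jnp.unique(E, return_inverse=True, return_counts=True, size=n, fill_value=jnp.nan)


E = jnp.array([0.0, 0.0, 1.0, 2.0, 2.0, 2.0])
labels = degeneracy_labels(E)  # [0, 0, 1, 2, 2, 2]
# Fixed-shape (n, n) mask: True where two entries belong to the same degenerate group.
same_group = labels[:, None] == labels[None, :]
# Unique values padded to length n, with per-entry group index and padded group counts.
values, inverse, counts = unique_padded(E)
```

Downstream code then works with `labels` or the `same_group` mask (for example to select block-diagonal parts of a matrix) rather than with a Python list of variable-length index arrays.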
-
Yes, what I'm after is the Takagi decomposition, a.k.a. symmetric SVD. It is a special case of the SVD: it states that any square complex-valued symmetric (not Hermitian) matrix A can be written as A = V D V^T, with V unitary and D a diagonal matrix containing the singular values of A.

The decomposition is not part of scipy/numpy, and hence I don't expect to find it in jax.numpy or jax.scipy. There are efficient algorithms for it, see this paper. It is also possible to implement it using the SVD/EIG, but that is less efficient.

I'm trying to write the decomposition in jax so I can run it on a GPU. I found two algorithms online: the first one is more efficient but cannot easily be jitted; the second one is inefficient but can be jitted.

from itertools import groupby
import numpy as np
import scipy.linalg as sla

def takagi(N, tol=1e-13, rounding=13):
    r"""Autonne-Takagi decomposition of a complex symmetric (not Hermitian!) matrix.
    Note that singular values of N are considered equal if they are equal after np.round(values, rounding).
    See :cite:`cariolaro2016` and references therein for a derivation.
    Args:
        N (array[complex]): square, symmetric matrix N
        rounding (int): the number of decimal places to use when rounding the singular values of N
        tol (float): the tolerance used when checking if the input matrix is symmetric: :math:`|N-N^T| <` tol
    Returns:
        tuple[array, array]: (rl, U), where rl are the (rounded) singular values,
            and U is the Takagi unitary, such that :math:`N = U \diag(rl) U^T`.
    """
    (n, m) = N.shape
    if n != m:
        raise ValueError("The input matrix must be square")
    if np.linalg.norm(N - np.transpose(N)) >= tol:
        raise ValueError("The input matrix is not symmetric")
    N = np.real_if_close(N)
    if np.allclose(N, 0):
        return np.zeros(n), np.eye(n)
    if np.isrealobj(N):
        # If the matrix N is real one can be more clever and use its eigendecomposition
        l, U = np.linalg.eigh(N)
        vals = np.abs(l)  # These are the Takagi eigenvalues
        phases = np.sqrt(np.complex128([1 if i > 0 else -1 for i in l]))
        Uc = U @ np.diag(phases)  # One needs to readjust the phases
        list_vals = [(vals[i], i) for i in range(len(vals))]
        list_vals.sort(reverse=True)
        sorted_l, permutation = zip(*list_vals)
        permutation = np.array(permutation)
        Uc = Uc[:, permutation]
        # And also rearrange the unitary and values so that they are decreasingly ordered
        return np.array(sorted_l), Uc
    v, l, ws = np.linalg.svd(N)
    w = np.transpose(np.conjugate(ws))
    rl = np.round(l, rounding)
    # Generate list with degeneracies
    result = []
    for k, g in groupby(rl):
        result.append(list(g))
    # Generate lists containing the columns that correspond to degeneracies
    kk = 0
    for k in result:
        for ind, j in enumerate(k):  # pylint: disable=unused-variable
            k[ind] = kk
            kk = kk + 1
    # Generate the lists with the degenerate column subspaces
    vas = []
    was = []
    for i in result:
        vas.append(v[:, i])
        was.append(w[:, i])
    # Generate the matrices qs of the degenerate subspaces
    qs = []
    for i in range(len(result)):
        qs.append(sla.sqrtm(np.transpose(vas[i]) @ was[i]))
    # Construct the Takagi unitary
    qb = sla.block_diag(*qs)
    U = v @ np.conj(qb)
    return rl, U

This algorithm requires special care when handling degenerate singular values, hence my original question. The algorithm below, however, can be jitted.

from jax import jit, lax
import jax.numpy as jnp

@jit
def lax_takagi(A):
    """Extremely simple and inefficient Takagi factorization of a
    symmetric, complex matrix A. Here we take this to mean A = U D U^T
    where D is a real, diagonal matrix and U is a unitary matrix. There
    is no guarantee that it will always work. """
    # Construct a Hermitian matrix.
    H = A.T.conj() @ A
    # Calculate the eigenvalue decomposition of the Hermitian matrix.
    # The diagonal matrix in the Takagi factorization is the square
    # root of the eigenvalues of this matrix.
    u, lam = lax.linalg.eigh(H)
    # The "almost" Takagi factorization. There is a conjugate here
    # so the final form is as given in the doc string.
    T = jnp.einsum('ji,jk,ki->i', u, A, u)
    # T is diagonal but not real. That is easy to fix by a
    # simple transformation which removes the complex phases
    # from the resulting diagonal matrix.
    c = jnp.exp(0.5j * jnp.angle(T))
    U = jnp.einsum('ij,j->ij', u.conj(), c)
    # Now A = np.dot(U, np.dot(np.diag(np.sqrt(lam)),U.T))
    return jnp.sqrt(lam[::-1]), U[:, ::-1]

It'd be nice to have a proper state-of-the-art implementation of the Takagi decomposition in lax.
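For reference, here is a small sketch of how the output of either routine above could be sanity-checked against the defining property A = U D U^T with U unitary. The helper name `takagi_residuals` and the test matrix are assumptions made for this illustration; the matrix is built from a known factorization so the expected answer is available.

```python
import jax
import jax.numpy as jnp


def takagi_residuals(A, d, U):
    """Return (reconstruction error, unitarity error) for a candidate A = U diag(d) U^T."""
    # U * d scales column j of U by d[j], which equals U @ diag(d).
    recon = jnp.max(jnp.abs(A - (U * d) @ U.T))
    unitary = jnp.max(jnp.abs(U.conj().T @ U - jnp.eye(U.shape[0])))
    return recon, unitary


# Build a complex symmetric test matrix from a known factorization.
key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
n = 4
Z = jax.random.normal(key_re, (n, n)) + 1j * jax.random.normal(key_im, (n, n))
U0, _ = jnp.linalg.qr(Z)             # Q factor of a random complex matrix is unitary
d0 = jnp.array([4.0, 3.0, 2.0, 1.0])  # chosen singular values, decreasing
A = (U0 * d0) @ U0.T                  # symmetric (not Hermitian) by construction

recon_err, unitary_err = takagi_residuals(A, d0, U0)
```

The same helper can be applied to a computed factorization, e.g. `takagi_residuals(A, *lax_takagi(A))`. With JAX's default single precision, residuals around 1e-6 rather than machine-precision zeros are to be expected.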
-
Hi,
I want to extract the indices of the entries of a jax array, grouped according to their degeneracies, e.g.
I have come up with two different ways of doing this:
and
Unfortunately, neither of the two ways can be placed in a function that first computes E and then extracts result from it, and is at the same time jit-able, because the operations groupby and jnp.unique explicitly depend on the size of E.

Does anyone know how I can do this operation under just-in-time compilation using jax?