First of all, thanks for building the great JAX ecosystem!

Suppose I have a list of vectors stored as an array of shape (N, 2), where N is the number of vectors and 2 is the fixed size of each vector. Would it be possible to compute the Kronecker product of all the vectors with native JAX operations? The Kronecker product of two vectors of shape [2] gives a vector of shape [4], so with N vectors one gets a vector of size 2^N.

If the number of vectors N is only fixed by the input (variable size), then it seems that naively applying … When I scrolled through the issues/discussions, I saw there are potentially some ways around this by using … I was wondering if there is an easy solution to this, or if one would need to wait for the dynamic shape functionality to be released?

Many thanks in advance.
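This small sketch is not from the thread and uses arbitrary example values. It illustrates the size growth described above. The Kronecker product of two 2-vectors has 4 entries, and folding `jnp.kron` over the N rows of an (N, 2) array gives 2**N entries. Using `functools.reduce` over the rows is an illustrative choice, and it assumes N is a concrete Python-level number when the fold runs.

```python
import functools

import jax.numpy as jnp

# Two 2-vectors: the Kronecker product has 2 * 2 = 4 entries.
a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0])
print(jnp.kron(a, b))  # [3. 4. 6. 8.]

# N vectors of size 2 stored as the rows of an (N, 2) array (example values).
N = 5
vecs = jnp.ones((N, 2))
# Fold jnp.kron over the rows: the size goes 2, 4, 8, ..., 2**N.
full = functools.reduce(jnp.kron, list(vecs))
print(full.shape)  # (32,)
```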
If I understand the question correctly, I think something like this will do what you have in mind:

```python
import jax
import jax.numpy as jnp
import numpy as np

def multi_kron(*vecs):
    # Build the outer product of all inputs by broadcasting, then flatten it.
    result = 1
    for i, vec in enumerate(vecs[::-1]):
        # Counting from the last input, the i-th vector gets i trailing
        # singleton axes so it broadcasts against the axes built so far.
        result = result * jnp.expand_dims(vec, range(1, 1 + i))
    # Row-major flattening matches the element ordering of jnp.kron.
    return result.ravel()

np.random.seed(0)
x = np.random.rand(2)
y = np.random.rand(3)

# Check multi_kron against jnp.kron for two inputs
print(jnp.kron(x, y))
# [0.33080468 0.29903927 0.23250748 0.43108994 0.38969466 0.3029934 ]
print(multi_kron(x, y))
# [0.33080468 0.29903927 0.23250748 0.43108994 0.38969466 0.3029934 ]

# Check multi_kron for >2 inputs
vecs = [np.random.rand(N) for N in [2, 3, 4]]
print(multi_kron(*vecs))
# [0.45602643 0.30463865 0.3271885 0.5331353 0.49278873 0.32919693
# 0.3535646 0.57611364 0.1960807 0.1309875 0.1406834 0.2292357
# 0.30895364 0.20638983 0.22166713 0.36119416 0.33385974 0.22302784
# 0.23953669 0.3903116 0.13284284 0.0887428 0.09531168 0.15530503]
```
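For the (N, 2) input from the question, one possible approach is sketched below under `jax.jit`. It is not from the thread. Array shapes are static under `jit`, so N is known at trace time and a plain Python loop over the rows is unrolled. Because the output size 2**N depends on N, each new N causes a retrace and recompile rather than one shape-polymorphic program. The function name `kron_rows` is hypothetical, and the fold uses `jnp.kron` rather than the broadcasting trick above.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def kron_rows(x):
    # x has shape (N, 2). Under jit, x.shape is static, so this loop over
    # rows is unrolled at trace time into N - 1 calls to jnp.kron.
    rows = [x[i] for i in range(x.shape[0])]
    return functools.reduce(jnp.kron, rows)


# Each distinct N yields a different output shape (2**N,), so jit traces
# and compiles a separate version of kron_rows for every N it sees.
np.random.seed(0)
for n in (2, 3, 4):
    x = np.random.rand(n, 2)
    out = kron_rows(x)
    # Cross-check against an explicit fold of jnp.kron outside jit.
    ref = functools.reduce(jnp.kron, [jnp.asarray(r) for r in x])
    print(n, out.shape, bool(jnp.allclose(out, ref)))
```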