If I understand the question correctly, I think something like this will do what you have in mind:

import jax
import jax.numpy as jnp
import numpy as np

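def multi_kron(*vecs):
  # Kronecker product of any number of 1D vectors, built by broadcasting.
  result = 1
  # Walk the vectors right to left; the i-th vector from the end gets i
  # trailing size-1 axes so it broadcasts against the running product.
  for i, vec in enumerate(vecs[::-1]):
    result = result * jnp.expand_dims(vec, range(1, 1 + i))
  # Flatten in row-major order, which matches the ordering used by jnp.kron.
  return result.ravel()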

np.random.seed(0)
x = np.random.rand(2)
y = np.random.rand(3)

# Check multi-kron against jnp.kron for two inputs
print(jnp.kron(x, y))
# [0.33080468 0.29903927 0.23250748 0.43108994 0.38969466 0.3029934 ]
print(multi_kron(x, y))
# [0.33080468 0.29903927 0.23250748 0.43108994 0.38969466 0.3029934 ]

# Check multi-kron for >2 inputs
vecs = [np.random.rand(N) for N in [2, 3, 4]]
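# Compare against nested jnp.kron calls
print(jnp.allclose(multi_kron(*vecs), jnp.kron(vecs[0], jnp.kron(vecs[1], vecs[2]))))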
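
To see why `range(1, 1 + i)` is the right set of axes, it can help to watch the shapes as the loop runs. The sketch below is an instrumented copy of `multi_kron`; `multi_kron_verbose` and its `trace` list are names introduced here for illustration only. It records the shape of each expanded input and of the running product for inputs of length 2, 3 and 4.

```python
import jax.numpy as jnp
import numpy as np


def multi_kron_verbose(*vecs):
  # Hypothetical instrumented copy of multi_kron that records shapes per step.
  result = 1
  trace = []
  for i, vec in enumerate(vecs[::-1]):
    expanded = jnp.expand_dims(vec, range(1, 1 + i))  # add i trailing size-1 axes
    result = result * expanded  # broadcasting extends the outer product by one axis
    trace.append((expanded.shape, result.shape))
  return result.ravel(), trace


vecs = [np.ones(n) for n in [2, 3, 4]]
out, trace = multi_kron_verbose(*vecs)
for expanded_shape, result_shape in trace:
  print(expanded_shape, "->", result_shape)
# (4,) -> (4,)
# (3, 1) -> (3, 4)
# (2, 1, 1) -> (2, 3, 4)
print(out.shape)  # (24,)
```

Because the first vector ends up on the leading axis, the row-major `ravel()` enumerates index tuples in the same order as `jnp.kron`, which is why the flattened result matches it.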
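
For comparison, the same values can be obtained by folding `jnp.kron` over the inputs with `functools.reduce`, which also gives an independent reference for the >2-input check. This is a minimal sketch assuming every input is a 1D array; `multi_kron_reduce` is a hypothetical name, and `multi_kron` is copied from above so the snippet runs on its own. The fold computes `kron(kron(v0, v1), v2)`, and since the Kronecker product is associative the grouping does not change the result (up to floating-point rounding).

```python
import functools

import jax.numpy as jnp
import numpy as np


def multi_kron(*vecs):
  # Copied from the answer above: broadcasting-based multi-input Kronecker product.
  result = 1
  for i, vec in enumerate(vecs[::-1]):
    result = result * jnp.expand_dims(vec, range(1, 1 + i))
  return result.ravel()


def multi_kron_reduce(*vecs):
  # Hypothetical alternative: left fold of pairwise jnp.kron calls.
  return functools.reduce(jnp.kron, vecs)


np.random.seed(0)
vecs = [np.random.rand(n) for n in [2, 3, 4]]
a = multi_kron(*vecs)
b = multi_kron_reduce(*vecs)
print(a.shape, b.shape)    # (24,) (24,)
print(jnp.allclose(a, b))  # True
```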

Answer selected by LionSR
Category: Q&A
2 participants