Suppose that I have a list `l = [[a, b, c], [d, e], [f]]`, where `a`, `b`, `c`, `d`, `e`, `f` are numbers. How can I create a JAX matrix from it? Of course
Well... the answer is obvious: we can pad the elements of the list `l` with zeros.
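A minimal sketch of that padding idea, done in plain Python before converting to JAX. It assumes the rows should be right-aligned (zeros on the left), which turns the list from the question into an upper-triangular matrix. The helper name `pad_rows` is hypothetical, and the numbers stand in for `a` through `f`.

```python
import jax.numpy as jnp

def pad_rows(rows, fill=0.0):
    """Hypothetical helper: left-pad each row with `fill` up to the longest row's length."""
    width = max(len(r) for r in rows)
    return [[fill] * (width - len(r)) + list(r) for r in rows]

l = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]  # stand-ins for a, b, c, d, e, f
m = jnp.array(pad_rows(l))  # every row now has the same length
print(m)
# [[1. 2. 3.]
#  [0. 4. 5.]
#  [0. 0. 6.]]
```

Putting the zeros on the right instead (`list(r) + [fill] * (width - len(r))`) would left-align the rows.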
I can think of several ways to do this in JAX; here are some examples:

```python
import jax.numpy as jnp

values = [jnp.array([1., 2., 3.]),
          jnp.array([4., 5.]),
          jnp.array([6.])]
dim = len(values[0])  # the first row is the longest; row i has dim - i entries

# Option 1: left-pad each row with zeros, then stack
arr1 = jnp.stack([jnp.pad(row, (dim - len(row), 0))
                  for row in values])
print(arr1)
# [[1. 2. 3.]
#  [0. 4. 5.]
#  [0. 0. 6.]]

# Option 2: write each row into a zero matrix, starting at the diagonal
arr2 = jnp.zeros((dim, dim))
for i, row in enumerate(values):
    arr2 = arr2.at[i, i:].set(row)  # functional update: returns a new array
print(arr2)
# [[1. 2. 3.]
#  [0. 4. 5.]
#  [0. 0. 6.]]

# Option 3: scatter all values at once into the upper-triangular indices
arr3 = jnp.zeros((dim, dim)).at[jnp.triu_indices(dim)].set(jnp.concatenate(values))
print(arr3)
# [[1. 2. 3.]
#  [0. 4. 5.]
#  [0. 0. 6.]]
```
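Option 3 is straightforward to wrap in `jax.jit`, because the only shape information it needs is the integer `dim`, which can be marked static. This is a sketch: the function name `triu_from_flat` is hypothetical, and it assumes the flattened input has exactly `dim * (dim + 1) // 2` entries in row-major order, which holds when row `i` has `dim - i` elements as in the question.

```python
from functools import partial

import jax
import jax.numpy as jnp

# `dim` is static so jnp.triu_indices can build concrete index arrays at trace time.
@partial(jax.jit, static_argnums=1)
def triu_from_flat(flat, dim):
    """Hypothetical helper: scatter dim*(dim+1)//2 values into the upper triangle."""
    rows, cols = jnp.triu_indices(dim)
    return jnp.zeros((dim, dim), dtype=flat.dtype).at[rows, cols].set(flat)

values = [jnp.array([1., 2., 3.]), jnp.array([4., 5.]), jnp.array([6.])]
arr4 = triu_from_flat(jnp.concatenate(values), len(values))
print(arr4)
# [[1. 2. 3.]
#  [0. 4. 5.]
#  [0. 0. 6.]]
```

Because `dim` is static, each distinct matrix size triggers a separate compilation; for a single fixed size this costs nothing extra after the first call.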