Hi - thanks for the question. There are two problems with your implementation
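Fixing these two problems, I can rewrite your code this way to make it compatible with `vmap`:

```python
import jax

def one_hot(idx):
    num_classes = 2
    idx_one_hot = jax.numpy.zeros(num_classes)
    # JAX arrays are immutable, so .at[idx].set(1) returns an updated copy.
    return idx_one_hot.at[idx].set(1)

arr = jax.numpy.array([0, 1])

# Calling the function one element at a time.
for i in arr:
    print(one_hot(i))
# [1. 0.]
# [0. 1.]

# Mapping the same function over the leading axis with vmap.
print(jax.vmap(one_hot, in_axes=0, out_axes=0)(arr))
# [[1. 0.]
#  [0. 1.]]
```

See this doc page for more information on the *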
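If you need a different number of classes, one possible variation is to pass it in as a regular Python argument and bind it with `functools.partial` before calling `vmap`. This is only a sketch: the `num_classes` parameter and the three-element `arr` below are my own illustration and are not taken from your code. The built-in `jax.nn.one_hot` is included for comparison.

```python
from functools import partial

import jax
import jax.numpy as jnp


def one_hot(idx, num_classes):
    # num_classes is a plain Python int (an assumption of this sketch),
    # so the output shape is known when the function is traced.
    return jnp.zeros(num_classes).at[idx].set(1)


arr = jnp.array([0, 1, 2])  # illustrative input

# Bind num_classes first, then map over the indices only.
batched = jax.vmap(partial(one_hot, num_classes=3))(arr)

# jax.nn.one_hot already accepts a batch of indices.
builtin = jax.nn.one_hot(arr, 3)

assert batched.shape == (3, 3)
assert bool(jnp.array_equal(batched, builtin))
```

Binding `num_classes` up front keeps it out of the mapped arguments, so `vmap` only ever sees the array of indices.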
Sorry for my poor English!
I want to use vmap with my own function, one_hot. The code is here:
When I use it like this:
it works. But when I use it with vmap, like this:
it raises an error, and I don't know how to fix it.