I cannot reproduce the error:

```python
from jax.example_libraries import stax
import jax


def f():
    return stax.serial(
        stax.Dense(128),
        stax.Dropout(0.1),
        stax.Dense(2)
    )


init_fn, apply_fn = f()
x = jax.numpy.ones((2,))
output_shape, params = init_fn(jax.random.PRNGKey(0), x.shape)
y = apply_fn(params, x, rng=jax.random.PRNGKey(42))
print(y)  # [-1.2811136 0.6775228]
```
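Two details of `stax.Dropout` are easy to miss. As far as I can tell from the `jax.example_libraries.stax` source, its first argument is the probability of *keeping* a unit (it samples `jax.random.bernoulli(rng, rate, ...)` and rescales kept units by `1 / rate`), so `Dropout(0.1)` drops roughly 90% of activations; if 10% dropout was intended, `Dropout(0.9)` would express that. Its `mode` argument (default `"train"`) makes the layer return its inputs unchanged for any other value, although the check for `rng` still runs first. The sketch below is a hypothetical variant of `f` above that takes a `mode` parameter, so train and eval apply functions can share one set of parameters.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax


def f(mode="train"):
    # Hypothetical variant of `f`: `mode` is forwarded to stax.Dropout.
    # Any value other than "train" makes Dropout return its inputs unchanged.
    return stax.serial(
        stax.Dense(128),
        stax.Dropout(0.1, mode=mode),  # 0.1 = probability of KEEPING a unit
        stax.Dense(2),
    )


init_fn, train_apply = f("train")
_, eval_apply = f("test")  # same layer structure, so the same params work

x = jnp.ones((2,))
output_shape, params = init_fn(jax.random.PRNGKey(0), x.shape)

# Training: use a fresh key per step to get a new dropout mask each time.
y_train = train_apply(params, x, rng=jax.random.PRNGKey(1))

# Evaluation: Dropout is a no-op, but it still raises if rng is missing,
# so pass a key anyway; the result does not depend on which key is used.
y_eval_a = eval_apply(params, x, rng=jax.random.PRNGKey(1))
y_eval_b = eval_apply(params, x, rng=jax.random.PRNGKey(2))

# A lone Dropout layer on many ones, to observe the keep-probability semantics.
_, drop_apply = stax.Dropout(0.1)
ones = jnp.ones((10_000,))
dropped = drop_apply((), ones, rng=jax.random.PRNGKey(3))
kept_fraction = float(jnp.mean(dropped != 0))  # roughly 0.1
max_value = float(jnp.max(dropped))  # kept units are scaled by 1 / 0.1
print(kept_fraction, max_value)
```

Keeping `mode` as a constructor argument mirrors how `stax` itself configures the layer; the only cost is building the model twice.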
Sorry for my poor English.
When I use Dropout like this:

```python
from jax.example_libraries import stax


def xiaohua_dense():
    return stax.serial(
        stax.Dense(128),
        stax.Dropout(0.1),
        stax.Dense(2))
```

it raises an error, and I do not know what's wrong with it.
How can I fix it?
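The post does not include the traceback, so this is an assumption, but one common cause is calling the model without a PRNG key. In `jax.example_libraries.stax`, the `apply_fn` returned by `serial` accepts an `rng` keyword and splits it into one subkey per layer; if none is given, `Dropout` receives `rng=None` and raises a `ValueError`. The sketch below reproduces that failure with `xiaohua_dense` and then applies the fix of passing `rng=...`.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax


def xiaohua_dense():
    return stax.serial(
        stax.Dense(128),
        stax.Dropout(0.1),
        stax.Dense(2))


init_fn, apply_fn = xiaohua_dense()
x = jnp.ones((2,))
out_shape, params = init_fn(jax.random.PRNGKey(0), x.shape)

# Assumed failing call: no `rng`, so serial passes rng=None to every layer
# and Dropout raises a ValueError asking for a PRNG key.
try:
    apply_fn(params, x)
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)

# Fix: pass a PRNG key as the `rng` keyword argument.
y = apply_fn(params, x, rng=jax.random.PRNGKey(42))
print(out_shape, y.shape)
```

`Dense` layers ignore the key; only `Dropout` needs it, but the key must still go through `serial`'s `apply_fn`.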