-
If you follow a typical Haiku setup:

import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
class MyModule(hk.Module):
    """Single affine layer computing ``jnp.dot(x, w) + b``.

    Parameters are registered through ``hk.get_parameter``:

    * ``w`` -- shape ``(x.shape[-1], output_size)``, truncated-normal init
      with stddev ``1 / sqrt(x.shape[-1])``.
    * ``b`` -- shape ``(output_size,)``, initialized to ones.

    Both parameters use the dtype of the input ``x``.
    """

    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, x):
        in_features = x.shape[-1]
        out_features = self.output_size
        # Scale the initial weights by 1/sqrt(fan_in).
        stddev = 1. / np.sqrt(in_features)
        w = hk.get_parameter(
            "w",
            shape=[in_features, out_features],
            dtype=x.dtype,
            init=hk.initializers.TruncatedNormal(stddev),
        )
        b = hk.get_parameter("b", shape=[out_features], dtype=x.dtype, init=jnp.ones)
        projected = jnp.dot(x, w)
        return projected + b
def f_fn(x):
    """Apply a freshly constructed single-output ``MyModule`` to ``x``."""
    layer = MyModule(output_size=1)
    out = layer(x)
    return out
# Turn f_fn into an (init, apply) pair; without_apply_rng removes the rng
# argument from apply, so it is later called as f.apply(params, x).
f = hk.without_apply_rng(hk.transform(f_fn))
# Example batch: 10 rows with 2 features each, cast to float32.
# NOTE(review): np.random is not seeded in this snippet, so x differs per run.
x = np.random.randn(10, 2).astype(np.float32)
rng = jax.random.PRNGKey(0)
# Initialize parameters. With x.shape[-1] == 2 and output_size == 1, MyModule
# creates 'w' with shape (2, 1) and 'b' with shape (1,) under 'my_module'.
params = f.init(rng, x)
yp = f.apply(params, x)

So you have params like this:
If you want to, for example, vary `w` and plot the result, how do you assemble the pytree for the evaluation? I don't see any examples in either the JAX or the Haiku docs.
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 3 replies
-
# Since FlatMapping doesn't implement __setitem__, one way to vary the contents
# is to flatten, modify, and unflatten the parameters. For example:
from jax import tree_util

params_flat, tree = tree_util.tree_flatten(params)
# Leaves come out in key order ('b' before 'w', as the printout below shows),
# so index 1 is 'w'. Reshape the replacement to that leaf's own shape:
# MyModule requests 'w' with shape (j, k) == (2, 1), and a 1-D (2,) array
# would no longer match it when new_params is passed to f.apply.
params_flat[1] = jnp.array([1., -1.]).reshape(params_flat[1].shape)
new_params = tree_util.tree_unflatten(tree, params_flat)
print(new_params)
# FlatMapping({
#   'my_module': FlatMapping({
#     'b': DeviceArray([1.], dtype=float32),
#     'w': DeviceArray([[ 1.],
#                       [-1.]], dtype=float32),
#   }),
# })
# Or, if you prefer, you can create your parameters from scratch using a Python
# dict and hk.data_structures.to_immutable_dict:
params = {
    'my_module': {
        # Shape (2, 1), not (2,): MyModule creates 'w' with shape (j, k),
        # where j = x.shape[-1] = 2 and k = output_size = 1.
        'w': jnp.array([[1.], [2.]]),
        'b': jnp.array([1.]),
    }
}
# Bind the converted mapping so it can be passed to f.apply(params, x).
params = hk.data_structures.to_immutable_dict(params)
print(params)
# FlatMapping({
#   'my_module': FlatMapping({
#     'w': DeviceArray([[1.],
#                       [2.]], dtype=float32),
#     'b': DeviceArray([1.], dtype=float32),
#   }),
# })

(The opposite function is `hk.data_structures.to_mutable_dict`.)
Beta Was this translation helpful? Give feedback.
Since `FlatMapping` doesn't implement `__setitem__`, one way to vary the contents is to flatten, modify, and unflatten the parameters. For example:

Or, if you prefer, you can create your parameters from scratch using a Python dict and `hk.data_structures.to_immutable_dict`: