Since FlatMapping doesn't implement __setitem__, you can't assign to its entries in place. One way to change the contents is to flatten the parameters, modify the resulting list of leaves, and unflatten it again. For example:

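import jax.numpy as jnp
from jax import tree_util

# Flatten the parameters into a list of leaves plus a tree definition.
params_flat, tree = tree_util.tree_flatten(params)
# Replace the second leaf (the 'w' entry, per the output below).
params_flat[1] = jnp.array([1., -1.])
# Rebuild a structure of the same type from the modified leaves.
new_params = tree_util.tree_unflatten(tree, params_flat)
print(new_params)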
# FlatMapping({
#   'my_module': FlatMapping({
#                  'b': DeviceArray([1.], dtype=float32),
#                  'w': DeviceArray([ 1., -1.], dtype=float32),
#                }),
# })
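
If you need to do this more than once, the flatten/modify/unflatten steps can be wrapped in a small helper. The following is only a sketch: `replace_leaf` is a hypothetical name, not part of JAX or Haiku, and a plain nested dict with made-up values stands in for the Haiku parameters so the snippet runs without Haiku installed.

```python
import jax.numpy as jnp
from jax import tree_util

# Plain nested dict standing in for the Haiku params (an assumption for
# illustration; the values are made up).
params = {'my_module': {'b': jnp.array([1.]), 'w': jnp.array([0., 0.])}}


def replace_leaf(tree, index, value):
    """Return a new tree with the leaf at flat position `index` replaced."""
    leaves, treedef = tree_util.tree_flatten(tree)
    leaves = list(leaves)  # work on a copy of the leaf list
    leaves[index] = value
    return tree_util.tree_unflatten(treedef, leaves)


# Dict leaves are flattened in sorted-key order: 'b' is index 0, 'w' is index 1.
new_params = replace_leaf(params, 1, jnp.array([1., -1.]))
```

The original `params` is left untouched; only `new_params` carries the updated 'w' entry. Addressing leaves by flat index is fragile if the parameter structure changes, so it is worth checking the flattened order before relying on a particular index.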

Or, if you prefer, you can create your parameters from scratch using a Python dict and hk.data_structures.to_immutable_dict:

params = {
  'my_module': {
    'w': jnp.array

Answer selected by cottrell