feat: switch from Haiku to Equinox #21

Draft
panahiparham wants to merge 1 commit into andnp:main from panahiparham:main

@panahiparham
Contributor

In this work-in-progress PR I am switching from Haiku to Equinox. The main changes are:

  1. There is no longer an AgentState attribute; instead there are network and opt_state attributes (DQN also has a target_network).
  2. Neural networks are now created and trained with Equinox and its filtered JAX transformations.

What is currently missing:

  1. Networks other than ReluNets (mainly MinatarNet and AtariNet)
  2. Speed (the code runs a little slower than with Haiku)
  3. Type annotations for the additions to DQN and EQRC

What I have verified so far:

  1. Ran DQN and EQRC and MC and checked that performance was good on 1 seed
  2. Checked that action values do not explode or become NaN
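
The NaN check in item 2 could be automated with a small helper like the one below. This is only a sketch: `all_finite` is a hypothetical name, and the example values are made up.

```python
import jax
import jax.numpy as jnp


def all_finite(tree):
    """Return True if every array leaf in a pytree is free of NaN and inf."""
    leaves = jax.tree_util.tree_leaves(tree)
    return all(bool(jnp.all(jnp.isfinite(leaf))) for leaf in leaves)


# Made-up action values for illustration.
q_values = {"q": jnp.array([0.5, -1.2, 3.0])}
bad_values = {"q": jnp.array([0.5, jnp.nan, 3.0])}

print(all_finite(q_values), all_finite(bad_values))  # prints: True False
```

The same helper works on a whole network or on an optimizer state, since both are pytrees.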

Please let me know what you think of the decisions above and whether they should be handled differently. Please also point out areas for improvement in speed and code quality, and any bugs.


  if self.updates % self.target_refresh == 0:
-     self.state.target_params = self.state.params
+     self.target_network = copy(self.network)
Contributor Author

Is this an appropriate way to handle target nets?

import utils.chex as cxu

@cxu.dataclass
class AgentState:
Contributor Author

What do you think about removing AgentState? We could add it back and store the model parameters alongside the optimizer parameters. Forward passes would then need many calls to eqx.partition and eqx.combine. Would that be slow?

assert isinstance(updates, dict)

decay = tree_map(
updates.heads['h'] = jax.tree.map(
Contributor Author

This step should be verified carefully.
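
One possible check for a per-head update is that only the intended subtree changes. The sketch below uses plain dictionaries instead of the module in this PR, and the decay rule (`update - scale * param`) is an assumption, since the actual rule is not shown in this fragment. `decay_head` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def decay_head(updates, params, head, scale):
    # Hypothetical helper: subtract scale * param from the updates of one head.
    new_head = jax.tree.map(
        lambda u, p: u - scale * p, updates[head], params[head]
    )
    # Return a new dict so the input updates are not modified.
    return {**updates, head: new_head}


params = {"q": {"w": jnp.ones((2, 2))}, "h": {"w": jnp.full((2, 2), 2.0)}}
updates = jax.tree.map(jnp.zeros_like, params)
decayed = decay_head(updates, params, "h", scale=0.1)

# Only the 'h' head should differ from the original updates.
h_changed = bool(jnp.allclose(decayed["h"]["w"], -0.2))
q_unchanged = bool(jnp.array_equal(decayed["q"]["w"], updates["q"]["w"]))
```

A test of this shape could be run against both the Haiku and Equinox versions to confirm that they produce the same updates.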

assert len(inputs) == 1
key_1, key_2 = jax.random.split(key, 2)

return eqx.nn.Sequential([
Contributor Author

Alternatively, each of these can be its own class in utils/eqx.py instead of being a Sequential.

import jax
import equinox as eqx

class MultiHead(eqx.Module):
Contributor Author

What do you think about this? I wanted a multi-head network to accommodate both QRC and DQN.

-     target_params=self.state.params,
-     optim=self.state.optim,
- )
+ self.target_network = copy(self.network)

Maybe this should be deepcopy instead of copy, so that it copies the object's values rather than references to them.

2 participants