Breaking Change
if self.updates % self.target_refresh == 0:
    self.state.target_params = self.state.params
    self.target_network = copy(self.network)
Is this an appropriate way to handle target nets?
import utils.chex as cxu

@cxu.dataclass
class AgentState:
What do you think about removing AgentState? We could add it back and store the model params alongside the optimizer parameters, but then every forward pass would need a call to eqx.partition and eqx.combine. Would that be slow?
assert isinstance(updates, dict)

decay = tree_map(
updates.heads['h'] = jax.tree.map(
This step should be verified carefully.
assert len(inputs) == 1
key_1, key_2 = jax.random.split(key, 2)

return eqx.nn.Sequential([
Alternatively, each of these can be its own class in utils/eqx.py instead of being a Sequential.
import jax
import equinox as eqx

class MultiHead(eqx.Module):
What do you think about this? I wanted a multi-head network to accommodate both QRC and DQN.
target_params=self.state.params,
optim=self.state.optim,
)
self.target_network = copy(self.network)
Maybe this should be deepcopy instead of copy, so that it copies the object's values rather than just references to them.
In this work-in-progress PR I am switching from Haiku to Equinox. The main changes are:
AgentState attribute, instead there are network and opt_state attributes (DQN also has a target_network)
What we are currently missing:
DQN and EQRC
What I have verified so far:
Please let me know what you think of the above decisions and whether we should handle any of them differently. Please also point out any bugs and any areas where speed or code quality could be improved.