Out Axes Specification when returning NamedTuples #8821
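Hi!

```python
inputs = {
    "cells": ...,      # array of shape (*, 3)
    "mesh_pos": ...,   # array of shape (*, 2)
    "node_type": ...,  # array of shape (*, 1)
    "feature1": ...,   # array of shape (batch_size, tsteps, *, 2)
    "feature2": ...,   # array of shape (batch_size, tsteps, *, 2)
}
```

The … I wanted to use JAX's `vmap` …

```python
import jax
import jraph


def _build(inputs: dict) -> jraph.GraphsTuple:
    # some fancy stuff happens
    # build a graph tuple
    graph = jraph.GraphsTuple(
        nodes=nodes,
        edges=edges,
        receivers=receivers,
        senders=senders,
        globals=None,
        n_node=n_node,
        n_edge=n_edge,
    )
    return graph


# Static mesh data is shared (None); per-step features are mapped over axis 0.
in_axes = {
    "cells": None,
    "mesh_pos": None,
    "node_type": None,
    "feature1": 0,
    "feature2": 0,
}
# vmap expects one in_axes entry per positional argument, hence the 1-tuple.
# The outer vmap maps over batch_size, the inner one over tsteps.
build = jax.vmap(jax.vmap(_build, in_axes=(in_axes,)), in_axes=(in_axes,))
```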
The above code does exactly what I want, i.e., batching only over … Is this somehow possible to specify using …?

I am happy for any help I get and appreciate every single contribution! Thanks in advance!

Best,
J
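A minimal runnable sketch of the setup above, assuming a hypothetical three-field `Graph` NamedTuple in place of `jraph.GraphsTuple` and a made-up `_build` body (the original's "fancy stuff" is not shown). It illustrates why the question comes up: with the default `out_axes=0`, every output field gets the batch and time axes, including `senders` and `receivers`, which depend only on the unbatched `cells` input and are simply broadcast.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Graph(NamedTuple):
    """Hypothetical stand-in for jraph.GraphsTuple (only three fields)."""
    nodes: jnp.ndarray
    senders: jnp.ndarray
    receivers: jnp.ndarray


def _build(inputs: dict) -> Graph:
    # Hypothetical body: node features depend on the batched "feature1" array.
    nodes = jnp.concatenate([inputs["mesh_pos"], inputs["feature1"]], axis=-1)
    # Connectivity depends only on the unbatched "cells" array.
    senders = inputs["cells"][:, 0]
    receivers = inputs["cells"][:, 1]
    return Graph(nodes=nodes, senders=senders, receivers=receivers)


batch_size, tsteps, num_nodes, num_cells = 2, 4, 5, 3
inputs = {
    "cells": jnp.arange(num_cells * 3).reshape(num_cells, 3),
    "mesh_pos": jnp.zeros((num_nodes, 2)),
    "feature1": jnp.ones((batch_size, tsteps, num_nodes, 2)),
}
# One in_axes entry per positional argument, so the dict is wrapped in a tuple.
in_axes = ({"cells": None, "mesh_pos": None, "feature1": 0},)

# Default out_axes=0: every output leaf gets the mapped axis, even unbatched ones.
build = jax.vmap(jax.vmap(_build, in_axes=in_axes), in_axes=in_axes)
graph = build(inputs)
print(graph.nodes.shape)    # (2, 4, 5, 4)
print(graph.senders.shape)  # (2, 4, 3): broadcast, although cells is unbatched
```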
The out axes specification should have the same tree structure as the outputs. This means that if you're returning a `GraphsTuple`, then the `out_axes` specification should also be a `GraphsTuple`, so something like this:

```python
# nodes mapped over axis 0; edges, receivers, senders, globals, n_node, n_edge unbatched
out_axes = jraph.GraphsTuple(0, None, None, None, None, None, None)
```

This is sometimes problematic for custom pytree implementations that do some sort of validation on their inputs; I don't know the best way to get around that issue.
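Under the same assumptions as before (a hypothetical three-field `Graph` NamedTuple instead of `jraph.GraphsTuple`, and a made-up `_build` body), this sketch passes an `out_axes` value of the same NamedTuple type to both `vmap` calls, so only `nodes` carries the batch and time axes.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Graph(NamedTuple):
    """Hypothetical stand-in for jraph.GraphsTuple (only three fields)."""
    nodes: jnp.ndarray
    senders: jnp.ndarray
    receivers: jnp.ndarray


def _build(inputs: dict) -> Graph:
    # Hypothetical body: nodes depend on batched data, connectivity does not.
    nodes = jnp.concatenate([inputs["mesh_pos"], inputs["feature1"]], axis=-1)
    senders = inputs["cells"][:, 0]
    receivers = inputs["cells"][:, 1]
    return Graph(nodes=nodes, senders=senders, receivers=receivers)


batch_size, tsteps, num_nodes, num_cells = 2, 4, 5, 3
inputs = {
    "cells": jnp.arange(num_cells * 3).reshape(num_cells, 3),
    "mesh_pos": jnp.zeros((num_nodes, 2)),
    "feature1": jnp.ones((batch_size, tsteps, num_nodes, 2)),
}
in_axes = ({"cells": None, "mesh_pos": None, "feature1": 0},)
# out_axes mirrors the output's tree structure: the same NamedTuple type,
# holding an axis (0) for batched fields and None for unbatched ones.
out_axes = Graph(nodes=0, senders=None, receivers=None)

# Both levels need out_axes; otherwise the inner vmap would already
# broadcast senders/receivers along the tsteps axis.
build = jax.vmap(
    jax.vmap(_build, in_axes=in_axes, out_axes=out_axes),
    in_axes=in_axes,
    out_axes=out_axes,
)
graph = build(inputs)
print(graph.nodes.shape)    # (2, 4, 5, 4)
print(graph.senders.shape)  # (3,)
```

A `None` entry is only valid for outputs that do not depend on any mapped input; otherwise `vmap` raises an error. For pytree classes whose constructor rejects `0` or `None` as field values, one possible sidestep (not from the thread) is to return a plain tuple of fields from the mapped function, pass a matching tuple such as `(0, None, None)` as `out_axes`, and build the container after the `vmap` call.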