Replies: 1 comment
Here's a draft PR where I implemented the solution proposed in this discussion. I haven't fixed the tests yet, as I am waiting to see whether there is any feedback on the class constructor signature or the overall idea.
tl;dr: Dynamic batching currently checks the number of nodes OR the number of edges before adding a graph to the batch. I propose checking both the number of nodes AND the number of edges.
Currently the dynamic batching method, which was based on this PR, takes in a max_num integer that determines when to stop adding graphs to the batch.
I copy the relevant lines here:
I want to update this class to take in both a max_node_budget and a max_edge_budget, instead of the single maximum number of nodes or edges it currently accepts. This is based on an approach used by the Jraph library in JAX (from DeepMind, but no longer actively updated) (link). The reason is that I'd like to make the algorithm more efficient by keeping the total memory required to store the edges/nodes/graphs in each batch constant; otherwise the code will need to recompile on GPU. Right now, if you specify only the maximum number of nodes, one batch of graphs may have X edges that, once padded, require memory equal to some power of 2. Another batch might have Y edges that do not fit into that same power of 2, which would force the code to recompile when run on GPU. This approach does have slightly more overhead (there are now two checks and two counters, max_num_nodes/max_num_edges), but the benefit of not recompiling outweighs this small overhead. Does that make sense?
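To make the two-budget check concrete, here is a minimal sketch in plain Python. It is not the code from the PR or from the existing sampler: the function name batch_by_budget, the sizes argument, and the decision to skip any graph that exceeds a budget on its own are all assumptions made for illustration.

```python
from typing import Iterable, Iterator, List, Tuple


def batch_by_budget(
    sizes: Iterable[Tuple[int, int]],
    max_num_nodes: int,
    max_num_edges: int,
) -> Iterator[List[int]]:
    """Yield lists of graph indices that fit BOTH the node and the edge budget.

    `sizes` holds one (num_nodes, num_edges) pair per graph.
    All names here are hypothetical; this only illustrates the check.
    """
    batch: List[int] = []
    nodes = 0  # running node count of the current batch
    edges = 0  # running edge count of the current batch
    for idx, (n, e) in enumerate(sizes):
        if n > max_num_nodes or e > max_num_edges:
            # Assumption: a graph that cannot fit on its own is skipped.
            continue
        # Close the current batch if adding this graph would break either budget.
        if batch and (nodes + n > max_num_nodes or edges + e > max_num_edges):
            yield batch
            batch, nodes, edges = [], 0, 0
        batch.append(idx)
        nodes += n
        edges += e
    if batch:
        yield batch


# Three graphs given as (num_nodes, num_edges).
sizes = [(10, 40), (10, 50), (5, 5)]
batches = list(batch_by_budget(sizes, max_num_nodes=32, max_num_edges=80))
print(batches)
```

With these numbers a node-only limit of 32 would put all three graphs (25 nodes) into one batch holding 95 edges. The extra edge check splits them so that no batch exceeds 80 edges, which means every batch can be padded to the same fixed node and edge sizes.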
I would propose the code looks like:
That said, I'm not sure whether you want to maintain backwards compatibility. If so, I would propose that max_num can also be a tuple of (max_num_edges, max_num_nodes) and that mode can also be set to 'both', so that both the nodes and the edges are checked. The code is uglier this way.
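As a rough sketch of what the backwards-compatible variant might involve, the helper below normalizes the arguments into a pair of budgets. The name resolve_budgets is hypothetical, and the existing mode values are assumed to be 'node' and 'edge'; only the tuple order (max_num_edges, max_num_nodes) and the 'both' mode come from the proposal above.

```python
from typing import Optional, Tuple, Union


def resolve_budgets(
    max_num: Union[int, Tuple[int, int]],
    mode: str = "node",
) -> Tuple[Optional[int], Optional[int]]:
    """Return (max_num_nodes, max_num_edges); None means that quantity is unchecked.

    Hypothetical helper. Assumes the existing modes are 'node' and 'edge'.
    """
    if mode == "both":
        if not isinstance(max_num, tuple) or len(max_num) != 2:
            raise ValueError(
                "mode='both' expects max_num=(max_num_edges, max_num_nodes)"
            )
        max_num_edges, max_num_nodes = max_num
        return max_num_nodes, max_num_edges
    if isinstance(max_num, tuple):
        raise ValueError("a tuple max_num is only valid with mode='both'")
    if mode == "node":
        return max_num, None
    if mode == "edge":
        return None, max_num
    raise ValueError(f"unknown mode: {mode!r}")


# Old-style calls keep working; the new call supplies both budgets.
old_style = resolve_budgets(100)
new_style = resolve_budgets((80, 32), mode="both")
```

The branching on both the type of max_num and the value of mode is where the extra ugliness comes from; two separate keyword arguments would avoid it at the cost of changing the constructor signature.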
I'd love to hear your thoughts.