Batch.from_data_list() causing bottleneck - any ideas? #3645
Replies: 1 comment
I always thought that [...] It's hard for me to tell how we can improve runtime (all we are doing is collecting lists of attributes and using [...]). Do you have a small example to reproduce the inefficiency in your case? I'm happy to look into any bottlenecks.
Hi, I'm using PyTorch Geometric for reinforcement learning (just a standard DQN), and my biggest computational bottleneck comes from constructing the batches with Batch.from_data_list(), as shown in the profiling below.
Other than reducing the batch size, I'm looking for ways of improving the efficiency of this step.
When sampling experiences to construct the batch, an experience contains a 'state' and a 'next state' (both of type Data). I construct two separate Batches - one using a list of 'states' and one using a list of 'next states' (note that the states aren't from consecutive experiences so the 'states' and 'next states' lists aren't just offset by 1; they're completely unrelated).
One idea I had: I could construct a single Data object with more keys (x_state, edge_index_state, edge_attr_state, x_next_state, edge_index_next_state, edge_attr_next_state). This would require only one call to from_data_list(), but with twice as many keys the function might simply take twice as long to run.
Any thoughts about how to speed this up (or even if it's possible) would be much appreciated!