Modifying edges dynamically in DataBatch #6334
-
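There is currently no straightforward way to maintain the indexing logic in a modified batch. One option is:

```python
import copy

# Shallow copy: batch1 itself is left untouched.
batch2 = copy.copy(batch1)
batch2.edge_index = new_edge_index
batch2.edge_attr = new_edge_attr
```

and modify …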
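A sketch of the bookkeeping a copied batch with new edges needs, under these assumptions: new edges only connect nodes of the same graph, edge_attr rows are aligned with edge_index columns, and node_batch is the node-to-graph vector (like batch.batch). Edges are stable-sorted so each graph's edges are contiguous, and a per-graph edge pointer (a CSR-style "compressed" index) is rebuilt from the per-graph edge counts. sort_edges_by_graph is a hypothetical helper; how that pointer is written back into the batch object depends on library internals that are not shown here.

```python
import torch


def sort_edges_by_graph(edge_index, edge_attr, node_batch, num_graphs):
    """Hypothetical helper: group edges by graph and build an edge pointer.

    Assumes every edge connects two nodes of the same graph, so the graph
    of the source node identifies the graph of the edge.
    """
    edge_graph = node_batch[edge_index[0]]  # graph id of each edge
    # Stable sort keeps the original edge order inside each graph.
    _, perm = torch.sort(edge_graph, stable=True)
    edge_index = edge_index[:, perm]
    edge_attr = edge_attr[perm] if edge_attr is not None else None
    # edges of graph g are columns edge_ptr[g]:edge_ptr[g + 1]
    counts = torch.bincount(edge_graph, minlength=num_graphs)
    edge_ptr = torch.zeros(num_graphs + 1, dtype=torch.long, device=edge_index.device)
    edge_ptr[1:] = torch.cumsum(counts, dim=0)
    return edge_index, edge_attr, edge_ptr


# Usage: nodes 0-2 form graph 0, nodes 3-4 form graph 1; edges are unsorted.
node_batch = torch.tensor([0, 0, 0, 1, 1])
edge_index = torch.tensor([[3, 0, 1],
                           [4, 1, 2]])
edge_attr = torch.tensor([[30.0], [0.0], [1.0]])
edge_index, edge_attr, edge_ptr = sort_edges_by_graph(edge_index, edge_attr, node_batch, 2)
print(edge_ptr.tolist())  # [0, 2, 3]
```

The stable sort matters: reordering columns of edge_index without applying the same permutation to edge_attr would silently misalign edge features.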
-
Yup, that was the mistake. I had to compute the compressed index of the … I am not sure what … Here is the implementation that works for me for now:
Does this seem right? Can you clarify the use of …?
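To illustrate what a compressed edge index is used for, here is a hypothetical helper, graph_edges, assuming edges are sorted by graph, edge_ptr holds per-graph edge offsets and node_ptr holds per-graph node offsets (like batch.ptr). Graph g's edges are then one contiguous slice, shifted back to local node ids. The second call shows one plausible way negative edge indices can appear: if the edges are not sorted by graph, the slice for a graph picks up another graph's edge, and subtracting the node offset goes below zero.

```python
import torch


def graph_edges(edge_index, edge_ptr, node_ptr, g):
    """Return the edges of graph g with local (0-based) node ids.

    Hypothetical helper: assumes edge_index columns are sorted by graph id,
    edge_ptr[g]:edge_ptr[g + 1] spans graph g's edges, and node_ptr[g] is the
    global id of graph g's first node.
    """
    start, end = int(edge_ptr[g]), int(edge_ptr[g + 1])
    # Shift global node ids back to ids local to graph g.
    return edge_index[:, start:end] - node_ptr[g]


# Two graphs: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
node_ptr = torch.tensor([0, 3, 5])
edge_ptr = torch.tensor([0, 2, 3])  # graph 0 has 2 edges, graph 1 has 1
edge_index = torch.tensor([[0, 1, 3],
                           [1, 2, 4]])
print(graph_edges(edge_index, edge_ptr, node_ptr, 1).tolist())  # [[0], [1]]

# Same edges, not sorted by graph: the slice for graph 1 now holds an edge
# of graph 0, and the offset subtraction produces negative ids.
unsorted = torch.tensor([[3, 0, 1],
                         [4, 1, 2]])
print(graph_edges(unsorted, edge_ptr, node_ptr, 1).tolist())  # [[-2], [-1]]
```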
-
UPDATE: I needed to sort the final edge_idx list before constructing the new batch. The fixed code is marked with …

Unfortunately, my code is still not working. I suspect the edge indexing is still off somehow, so I am posting a full runnable example in case I missed something else. For context, I am trying to generate new graphs from an initial graph by "flipping" some edges (i.e. adding new edges and removing old ones) as part of an MCMC sampling algorithm. Since I have more than 10,000 graphs in my training data, I would like to perform this sampling on a batch of graphs at once. The batch version of my algorithm does the following:
I am trying to do everything with PyTorch tensors to maximize GPU utilization. Here is a minimal runnable example:
The print statement above always shows some graphs with random negative edge indices, something like:
Here is the implementation:
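For comparison, a minimal sketch of one batched flip step that stays in tensors, under these assumptions: node_batch and node_ptr are the node-to-graph vector and per-graph node offsets (like batch.batch and batch.ptr), each removed edge is replaced by one uniformly random edge inside the same graph, and edge_attr, self-loops, duplicate edges and the MCMC acceptance step are not handled. flip_edges and check_batched_edges are hypothetical names. The final stable sort keeps every graph's edges contiguous, in line with the sorting fix from the UPDATE, and the check helper reports edges that cross graphs or break that ordering, both of which can lead to negative local indices.

```python
import torch


def flip_edges(edge_index, node_batch, node_ptr, p_remove=0.1, generator=None):
    """Hypothetical batched edge flip: drop random edges, then add the same
    number of random edges inside each affected graph."""
    device = edge_index.device
    # 1) Drop each edge independently with probability p_remove.
    keep = torch.rand(edge_index.size(1), generator=generator, device=device) >= p_remove
    kept = edge_index[:, keep]
    # Graph id of every removed edge; each one gets a replacement in that graph.
    removed_graph = node_batch[edge_index[0, ~keep]]
    # 2) Sample both endpoints uniformly among the nodes of that graph.
    start = node_ptr[removed_graph]
    size = node_ptr[removed_graph + 1] - start
    u = torch.rand(2, removed_graph.numel(), generator=generator,
                   device=device, dtype=torch.float64)
    new_edges = start + (u * size).long()  # row 0: sources, row 1: targets
    # 3) Stable-sort by graph id so every graph's edges stay contiguous.
    out = torch.cat([kept, new_edges], dim=1)
    _, perm = torch.sort(node_batch[out[0]], stable=True)
    return out[:, perm]


def check_batched_edges(edge_index, node_batch):
    """Count edges that cross graphs and test whether edges are sorted by graph."""
    src_graph = node_batch[edge_index[0]]
    dst_graph = node_batch[edge_index[1]]
    return {
        "cross_graph_edges": int((src_graph != dst_graph).sum()),
        "sorted_by_graph": bool((src_graph[1:] >= src_graph[:-1]).all()),
    }


# Usage: graph 0 = nodes 0-2, graph 1 = nodes 3-6.
node_batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])
node_ptr = torch.tensor([0, 3, 7])
edge_index = torch.tensor([[0, 1, 3, 4, 5],
                           [1, 2, 4, 5, 6]])
gen = torch.Generator().manual_seed(0)
flipped = flip_edges(edge_index, node_batch, node_ptr, p_remove=0.5, generator=gen)
print(check_batched_edges(flipped, node_batch))  # {'cross_graph_edges': 0, 'sorted_by_graph': True}
```

Because each replacement edge is drawn inside the graph of the edge it replaces, the number of edges per graph is preserved, which keeps any per-graph edge pointer consistent across steps.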
-
Yes, exactly, I figured it out yesterday :) There is also another bug in this line: …
-
I have a DataBatch object made up of several graphs. In one step of my algorithm, I change the edges of the graphs (in their batched form). I would like to get a new object with the modified edge_index/edge_attr and the full functionality of DataBatch. Here is an example:
Here are a few things I tried:
Solution 1:
With this solution, I cannot get a list of the graphs making up the batch via indexing (e.g. graph2[graph2.batch[:3]]).
Solution 2:
In this case, indexing graph2 by batch index returns the same graphs as graph1 (i.e. the new edges are not reflected in the graphs making up the batch).
I would like to avoid using Python lists so that all my computations stay on the GPU.
Thanks in advance!