-
Hi, I encounter a problem when using torch_geometric. Below is an example:

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader


class MyData(Data):
    def __init__(self, x=None, indices=None):
        # Keyword defaults so that batching can re-instantiate the class without arguments.
        super().__init__()
        self.x = x
        self.indices = indices

    def __inc__(self, key, value, *args, **kwargs):
        if key == 'indices':
            # Offset 'indices' by the number of nodes when graphs are batched.
            return self.x.size(0)
        return super().__inc__(key, value, *args, **kwargs)


x = torch.randn(7, 1)
indices = [[0, 1, 2, 3], [2, 4], [5, 6]]  # Sub-lists of different lengths.

d = MyData(x, indices)
data_list = [d, d]
loader = DataLoader(data_list, batch_size=2)
print(next(iter(loader)).indices)

The indices of the second graph do not get incremented during batching. I also tried to make the elements in the list long tensors, but that doesn't work because of their different lengths (see the illustration below). I also looked at the implementation in […]. Thanks in advance!
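For concreteness, here is the conversion that fails because of the ragged lengths (a minimal illustration of what I mean; the exact error message may vary across PyTorch versions):

import torch

# Sub-lists of different lengths cannot form one rectangular tensor:
torch.tensor([[0, 1, 2, 3], [2, 4], [5, 6]])
# ValueError: expected sequence of length 4 at dim 1 (got 2)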
-
Yeah, you are right. Data.__inc__ is currently only applied to PyTorch tensors. We probably need to at least emit a warning in case a user wants to increment a non-incrementable object. Let me know if you have interest in adding this :)

One workaround is to save indices as a one-dimensional tensor and to utilize torch.Tensor.split after batching. A sketch of this (the sizes attribute name is just a suggestion for storing the sub-list lengths):

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader


class MyData(Data):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'indices':
            # Offset the flat index tensor by the number of nodes per graph.
            return self.x.size(0)
        return super().__inc__(key, value, *args, **kwargs)


x = torch.randn(7, 1)
# Flatten [[0, 1, 2, 3], [2, 4], [5, 6]] into a single one-dimensional tensor ...
indices = torch.tensor([0, 1, 2, 3, 2, 4, 5, 6])
# ... and store the length of each sub-list so we can split again after batching:
sizes = torch.tensor([4, 2, 2])

d = MyData(x=x, indices=indices, sizes=sizes)
data_list = [d, d]
loader = DataLoader(data_list, batch_size=2)

batch = next(iter(loader))
# Recover the per-graph sub-lists, now with correctly offset indices:
print(batch.indices.split(batch.sizes.tolist()))

This should also be way faster than working with Python lists.
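With the toy example above (two copies of the same 7-node graph), the printed split should come out roughly as follows; note that the second graph's indices are offset by 7, the number of nodes in the first graph:

(tensor([0, 1, 2, 3]), tensor([2, 4]), tensor([5, 6]),
 tensor([7, 8, 9, 10]), tensor([9, 11]), tensor([12, 13]))

If nested Python lists are needed downstream, they can be recovered with, e.g., [t.tolist() for t in batch.indices.split(batch.sizes.tolist())].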