-
Hi, I learned from the instructions how to create mini-batches for paired graphs. However, when I use the DataLoader on a data_list composed of batched graphs, it raises an error. Is it possible to use the DataLoader to batch a list of batched graphs? For example, suppose each single graph is itself composed of a batch of graphs:

from torch_geometric.data import Data, DataLoader, Batch
import torch
import torch.nn.functional as F
from copy import deepcopy
class PairData(Data):
    """Data object that stores two graphs side by side.

    One graph is kept in ``x_s`` / ``edge_index_s`` and the other in
    ``x_t`` / ``edge_index_t`` (presumably "source" and "target" --
    the names are not explained in this file).
    """

    def __init__(self, edge_index_s=None, x_s=None, edge_index_t=None, x_t=None):
        """Store the four tensors as attributes; every argument is optional."""
        super().__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t

    def __inc__(self, key, value):
        """Return the increment associated with attribute ``key``.

        For ``'edge_index_s'`` this is the size of ``x_s`` along its first
        dimension, for ``'edge_index_t'`` the size of ``x_t`` along its
        first dimension; every other key is delegated to the parent class.
        """
        if key == 'edge_index_s':
            increment = self.x_s.size(0)
        elif key == 'edge_index_t':
            increment = self.x_t.size(0)
        else:
            increment = super().__inc__(key, value)
        return increment
if __name__ == '__main__':
    # One template pair of graphs: edges 0 -> 1..4 and 0 -> 1..3.
    edge_index_s = torch.tensor([[0, 0, 0, 0], [1, 2, 3, 4]])
    x_s = torch.randn(5, 16)  # 5 nodes.
    edge_index_t = torch.tensor([[0, 0, 0], [1, 2, 3]])
    x_t = torch.randn(4, 16)  # 4 nodes.
    data = PairData(edge_index_s, x_s, edge_index_t, x_t)

    bs = 10
    # suppose a graph is composed by a random number of batched graphs
    inner_bs = torch.randint(1, 3, (bs, 1))

    data_list = []
    for inner_count in inner_bs:
        # Each entry of the dataset is itself a Batch of copies of `data`.
        inner_copies = [deepcopy(data) for _ in range(inner_count)]
        data_list.append(
            Batch.from_data_list(inner_copies, follow_batch=['x_s', 'x_t'])
        )

    loader = DataLoader(data_list, batch_size=2)
    # Fetching the first mini-batch collates the Batch objects again.
    batch = next(iter(loader))
    print(batch)
The output is:

Traceback (most recent call last):
File "C:\ProgramData\Anaconda3\envs\rayNew\lib\site-packages\torch\utils\data\dataloader.py", line 517, in __next__
data = self._next_data()
File "C:\ProgramData\Anaconda3\envs\rayNew\lib\site-packages\torch\utils\data\dataloader.py", line 557, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "C:\ProgramData\Anaconda3\envs\rayNew\lib\site-packages\torch\utils\data\_utils\fetch.py", line 47, in fetch
return self.collate_fn(data)
File "C:\ProgramData\Anaconda3\envs\rayNew\lib\site-packages\torch_geometric\data\dataloader.py", line 36, in __call__
return self.collate(batch)
File "C:\ProgramData\Anaconda3\envs\rayNew\lib\site-packages\torch_geometric\data\dataloader.py", line 16, in collate
return Batch.from_data_list(batch, self.follow_batch,
File "C:\ProgramData\Anaconda3\envs\rayNew\lib\site-packages\torch_geometric\data\batch.py", line 108, in from_data_list
cumsum[key].append(inc + cumsum[key][-1])
TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
python-BaseException

It seems this line would return `None` for the batched graphs.
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 3 replies
-
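I see. The problem with that is that the `Batch` class does not have access to the `PairData.__inc__` method. I currently do not see a way to allow that. I think converting your `Batch` back to a `PairData` object yields the most elegant formulation to fix this:

from torch_geometric.data import Data, DataLoader, Batch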
import torch
from copy import deepcopy
class PairData(Data):
    """Data object whose ``__inc__`` handles the two edge-index attributes.

    The constructor takes arbitrary keyword arguments and forwards them
    unchanged to the parent class.
    """

    def __init__(self, **kwargs):  # We allow arbitrary arguments
        super().__init__(**kwargs)

    def __inc__(self, key, value):
        """Return the increment associated with attribute ``key``.

        ``'edge_index_s'`` maps to the first-dimension size of ``x_s`` and
        ``'edge_index_t'`` to the first-dimension size of ``x_t``; any other
        key falls through to the parent implementation.
        """
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        return super().__inc__(key, value)
if __name__ == '__main__':
    # One template pair of graphs: edges 0 -> 1..4 and 0 -> 1..3.
    edge_index_s = torch.tensor([[0, 0, 0, 0], [1, 2, 3, 4]])
    x_s = torch.randn(5, 16)  # 5 nodes.
    edge_index_t = torch.tensor([[0, 0, 0], [1, 2, 3]])
    x_t = torch.randn(4, 16)  # 4 nodes.
    data = PairData(
        edge_index_s=edge_index_s,
        x_s=x_s,
        edge_index_t=edge_index_t,
        x_t=x_t,
    )

    bs = 10
    # suppose a graph is composed by a random number of batched graphs
    inner_bs = torch.randint(1, 3, (bs, 1))

    data_list = []
    for inner_count in inner_bs:
        inner_copies = [deepcopy(data) for _ in range(inner_count)]
        merged = Batch.from_data_list(inner_copies, follow_batch=['x_s', 'x_t'])
        # Convert back to `PairData`
        data_list.append(PairData(**merged.to_dict()))

    loader = DataLoader(data_list, batch_size=2)
    batch = next(iter(loader))
    print(batch)
Beta Was this translation helpful? Give feedback.
I see. The problem with that is that the `Batch` class does not have access to the `PairData.__inc__` method. I currently do not see a way to allow that. I think converting your `Batch` back to a `PairData` object yields the most elegant formulation to fix this: