-
I want to store two graphs in a single Data object (PairData), as mentioned in the PyTorch Geometric Documentation (part of ADVANCED MINI-BATCHING). Then, I am trying to create a new database class consisting of PairData objects, from InMemoryDataset class. However, some confusing errors appeared during the code execution. I use DataLoader to load the data and train the graph neural network NNConv (imitating the process of Example1:Neural Message Passing for Quantum Chemistry (ICML 2017)) When the code runs to here: path = osp.join(osp.dirname(osp.realpath(__file__)), 'data')
dataset = ILs2Dataset(root=path).shuffle() errors are reported: Traceback (most recent call last):
File "main2.py", line 27, in <module>
dataset = ILs2Dataset(root=path).shuffle()
File "D:\try\data.py", line 107, in __init__
super().__init__(root, transform, pre_transform, pre_filter)
File "D:\software\anaconda3\envs\pyg\lib\site-packages\torch_geometric\data\in_memory_dataset.py", line 54, in __init__
pre_filter)
File "D:\software\anaconda3\envs\pyg\lib\site-packages\torch_geometric\data\dataset.py", line 92, in __init__
self._process()
File "D:\software\anaconda3\envs\pyg\lib\site-packages\torch_geometric\data\dataset.py", line 165, in _process
self.process()
File "D:\try\data.py", line 174, in process
torch.save(self.collate(data_list), self.processed_paths[0])
File "D:\software\anaconda3\envs\pyg\lib\site-packages\torch_geometric\data\in_memory_dataset.py", line 104, in collate
data = data_list[0].__class__()
TypeError: __init__() missing 7 required positional arguments: 'x_ca', 'edge_index_ca', 'edge_attr_ca', 'x_an', 'edge_index_an', 'edge_attr_an', and 'y' D:\software\anaconda3\envs\pyg\lib\site-packages\torch_geometric\data\in_memory_dataset.py in collate(data_list)
102 format of :class:`torch_geometric.data.InMemoryDataset`."""
103 keys = data_list[0].keys
--> 104 data = data_list[0].__class__()
105
106 for key in keys:
TypeError: __init__() missing 7 required positional arguments: 'x_ca', 'edge_index_ca', 'edge_attr_ca', 'x_an', 'edge_index_an', 'edge_attr_an', and 'y' I think it may be necessary to modify the import pandas as pd
from tqdm import tqdm
import torch
from torch_geometric.data import (InMemoryDataset, Data)
class PairData(Data):
""" Paired data type. Each object has 2 graphs.
ca represents cation, an represents anion.
"""
def __init__(self, x_ca, edge_index_ca, edge_attr_ca,
x_an, edge_index_an, edge_attr_an, y):
super(IonPairData, self).__init__()
self.x_ca = x_ca
self.edge_index_ca = edge_index_ca
self.edge_attr_ca = edge_attr_ca
self.x_an = x_an
self.edge_index_an = edge_index_an
self.edge_attr_an = edge_attr_an
self.y = y
def __inc__(self, key, value):
if key == 'edge_index_ca':
return self.x_ca.size(0)
if key == 'edge_index_an':
return self.x_an.size(0)
else:
return super().__inc__(key, value) class ILs2Dataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None,
pre_filter=None):
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])
def mean(self, target):
y = torch.cat([self.get(i).y for i in range(len(self))], dim=0)
return y[:, target].mean().item()
def std(self, target):
y = torch.cat([self.get(i).y for i in range(len(self))], dim=0)
return y[:, target].std().item()
@property
def raw_file_names(self):
return 'raw_ils.csv'
@property
def processed_file_names(self):
return 'data.pt'
def download(self):
pass
def process(self):
df = pd.read_csv(self.raw_paths[0], header=None)
df.loc[:, 3:] = df.loc[:, 3:].astype(float)
raw_data = df.values.tolist()
data_list = []
for _raw in tqdm(raw_data):
### The code for data extraction here is omitted. ###
data = PairData(x_ca=x_ca,
edge_index_ca=edge_index_ca,
edge_attr_ca=edge_attr_ca,
x_an=x_an,
edge_index_an=edge_index_an,
edge_attr_an=edge_attr_an,
y=y)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
data_list.append(data)
torch.save(self.collate(data_list), self.processed_paths[0]) I also tried to create a new database class from Dataset class. But when the code runs to the training set2set layer, the following error will be reported. torch_geometric\nn\glob\set2set.py", line 50, in forward
batch_size = batch.max().item() + 1
AttributeError: 'NoneType' object has no attribute 'max' Thank you very much for reading and look forward to your reply! |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 3 replies
-
After I modified the def __init__(self, x_ca=None, edge_index_ca=None, edge_attr_ca=None,
x_an=None, edge_index_an=None, edge_attr_an=None, y=None):
super(PairData, self).__init__() However, when the code runs to the training set2set layer, the following error will be reported. Traceback (most recent call last):
File "main2.py", line 141, in <module>
loss = train(epoch)
File "main2.py", line 121, in train
loss = F.mse_loss(model(data), data.y)
File "D:\software\anaconda3\envs\pyg\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "main2.py", line 95, in forward
out_ca = self.set2set_ca(out_ca, data.batch)
File "D:\software\anaconda3\envs\pyg\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\software\anaconda3\envs\pyg\lib\site-packages\torch_geometric\nn\glob\set2set.py", line 50, in forward
batch_size = batch.max().item() + 1
AttributeError: 'NoneType' object has no attribute 'max' Thank you very much for reading and look forward to your reply! |
Beta Was this translation helpful? Give feedback.
After I modified the
__init__()
function of thePairData
class, the previous error was gone.However, when the code runs to the training set2set layer, the following error will be reported.