-
Hi Matt, thank you for your help. Now I have working model with one class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = SplineConv(1, 32, dim=2, kernel_size=5) # d.num_features:=1=in_channels, 32: out_channels
self.conv2 = SplineConv(32, 16, dim=2, kernel_size=5)
self.conv3 = SplineConv(16, 32, dim=2, kernel_size=5)
self.conv4 = SplineConv(32, 1, dim=2, kernel_size=5)
def forward(self, data):
data.x = data.x.to(torch.float32)
data.y = data.y.to(torch.float32)
data.edge_attr = data.edge_attr.to(torch.float32)
data.pos = data.pos.to(torch.float32)
data.x = torch.tanh(self.conv1(data.x, data.edge_index, data.edge_attr))
weight1 = normalized_cut_2d(data.edge_index, data.pos)
cluster1 = graclus(data.edge_index, weight1, data.x.size(0))
coarsened_data_1 = max_pool(cluster1, data, transform=transform)
coarsened_data_1.x = torch.tanh(self.conv2(coarsened_data_1.x, coarsened_data_1.edge_index, coarsened_data_1.edge_attr))
weight2 = normalized_cut_2d(coarsened_data_1.edge_index, coarsened_data_1.pos)
cluster2 = graclus(coarsened_data_1.edge_index, weight2, coarsened_data_1.x.size(0))
coarsened_data_2 = max_pool(cluster2, coarsened_data_1, transform=transform)
cluster2, perm2 = self.consecutive_cluster(cluster2)
upsampled_x_2 = coarsened_data_2.x[cluster2]
x = upsampled_x_2
x = self.conv3(x, coarsened_data_2.edge_index, coarsened_data_2.edge_attr)
cluster1, perm1 = self.consecutive_cluster(cluster1)
upsampled_x_1 = coarsened_data_1.x[cluster1]
x = self.conv4(upsampled_x_1, data.edge_index, data.edge_attr)
x = x.view(-1)
return x.to(torch.float32) I know my code looks dirty (sorry >:), but I wanted to have a model like this. Thanks. The error message is like this. RuntimeError Traceback (most recent call last)
<ipython-input-11-a7d340e7534e> in <module>
1 for epoch in range(1, 6): # initial epoch = 31
----> 2 train(epoch)
3 test_acc = test()
4 print('Epoch: {:02d}, Test MSE: {:.4f}'.format(epoch, test_acc))
<ipython-input-10-0cc1431bb645> in train(epoch)
18 recon = data.x.to(torch.float32)
---> 19 pred = model(data)
~/miniconda3/envs/gnb_df_pyg/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
<ipython-input-8-efa986e8e223> in forward(self, data)
---> 56 x = self.conv4(upsampled_x_1, data.edge_index, data.edge_attr)
57
58
~/miniconda3/envs/gnb_df_pyg/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~/miniconda3/envs/gnb_df_pyg/lib/python3.7/site-packages/torch_geometric/nn/conv/spline_conv.py in forward(self, x, edge_index, edge_attr, size)
120
121 # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
--> 122 out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
123
124 x_r = x[1]
~/miniconda3/envs/gnb_df_pyg/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py in propagate(self, edge_index, size, **kwargs)
235
236 msg_kwargs = self.inspector.distribute('message', coll_dict)
--> 237 out = self.message(**msg_kwargs)
238
239 # For `GNNExplainer`, we require a separate message and aggregate
~/miniconda3/envs/gnb_df_pyg/lib/python3.7/site-packages/torch_geometric/nn/conv/spline_conv.py in message(self, x_j, edge_attr)
134 data = spline_basis(edge_attr, self.kernel_size, self.is_open_spline,
135 self.degree)
--> 136 return spline_weighting(x_j, self.weight, *data)
137
138 def __repr__(self):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "/home/hlx/miniconda3/envs/gnb_df_pyg/lib/python3.7/site-packages/torch_spline_conv/weighting.py", line 8, in spline_weighting
basis: torch.Tensor,
weight_index: torch.Tensor) -> torch.Tensor:
return torch.ops.torch_spline_conv.spline_weighting(
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
x, weight, basis, weight_index)
RuntimeError: x.size(1) == weight.size(1) INTERNAL ASSERT FAILED at "/home/travis/build/rusty1s/pytorch_spline_conv/csrc/cuda/weighting_cuda.cu":47, please report a bug to PyTorch. Input mismatch |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 25 replies
-
I guess you need to upsample upsampled_x_1 = x[cluster1] instead of upsampled_x_1 = coarsened_data_1.x[cluster1] |
Beta Was this translation helpful? Give feedback.
I guess you need to upsample
x
from the previous layer, e.g.:instead of