torch_geometric.transforms.AddRandomWalkPE doesn't support HeteroData() #6405
Most transformations support `HeteroData`, but `AddRandomWalkPE` does not. For example:
import torch
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData, Data
transform = T.AddRandomWalkPE(walk_length=4, attr_name='pe')
nodes = torch.empty((5,6)).float()
edge_idx = torch.Tensor([[0,0,1,1,2,2,3,3],[1,2,0,3,0,3,1,2]]).long()
edge_idx_aug = torch.Tensor([[0,1,2,3], [0,0,0,0]]).long()
hetero_data=HeteroData()
hetero_data['place'].x = nodes[:-1,:]
hetero_data['room'].x = nodes[-1,:].view(-1, nodes.shape[1])
hetero_data['place', 'with', 'place'].edge_index = edge_idx
hetero_data['place', 'connects', 'room'].edge_index = edge_idx_aug
post_data = transform(hetero_data)

You will get an error.
One possible solution is to create a homogeneous `Data` object from the `place` nodes, apply the transform to it, and copy the result back:
place_data = Data(x=nodes[:-1,:], edge_index=edge_idx)
post_data = transform(place_data)
hetero_data['place'].pe = post_data.pe

Is there any better way to solve this?
You are right, `AddRandomWalkPE` only supports homogeneous graphs. The paper it is based on only discusses homogeneous graphs. Approaches to tackle this are [...] `place` nodes.