The forward input to
I know I'm late to the party, but for anyone struggling with this, I wrote the following function:

```python
import torch
import torch.nn.functional as F
from torch_geometric.data import HeteroData


def transform_hetero_data(data: HeteroData):
    # ========================
    # Nodes
    # ========================
    # Extract node features from the input data dictionary
    x_dict = data.x_dict
    # Determine the maximum number of features (columns) across all node types
    max_node_cols = max(t.size(1) for t in x_dict.values())
    # Initialize lists to store node features and node types
    x = []
    node_type = []
    # Index at which each node type starts in the concatenated node tensor
    node_offset = {}
    offset = 0
    # Loop through each node type and pad the features to match max_node_cols
    for type_id, (name, t) in enumerate(x_dict.items()):
        node_offset[name] = offset
        offset += t.size(0)
        # Pad node features on the right with zeros so all types share max_node_cols columns
        t = F.pad(t, (0, max_node_cols - t.size(1)), mode='constant', value=0.0)
        x.append(t)
        # Integer node-type label for every node of this type
        node_type.append(torch.full((t.size(0),), type_id, dtype=torch.long, device=t.device))
    # Concatenate all node features and node types into single tensors
    x = torch.cat(x)
    node_type = torch.cat(node_type)

    # ========================
    # Edge Attributes
    # ========================
    # Extract edge attributes (assumes every edge type stores edge_attr,
    # so the order matches edge_index_dict below)
    edge_attr_dict = data.edge_attr_dict
    # Determine the maximum number of edge features (columns) across all edge types
    max_edge_cols = max(t.size(1) for t in edge_attr_dict.values())
    edge_attr = []
    edge_type = []
    # Loop through each edge type and pad the features to match max_edge_cols
    for type_id, t in enumerate(edge_attr_dict.values()):
        t = F.pad(t, (0, max_edge_cols - t.size(1)), mode='constant', value=0.0)
        edge_attr.append(t)
        # Integer edge-type label for every edge of this type
        edge_type.append(torch.full((t.size(0),), type_id, dtype=torch.long, device=t.device))
    # Concatenate all edge attributes and edge types into single tensors
    edge_attr = torch.cat(edge_attr)
    edge_type = torch.cat(edge_type)

    # ========================
    # Edges
    # ========================
    # Each edge type's edge_index uses node ids local to its source and destination
    # node types, so shift them by the offsets computed above before concatenating
    edge_index = []
    for (src, _, dst), ei in data.edge_index_dict.items():
        shift = torch.tensor([[node_offset[src]], [node_offset[dst]]], device=ei.device)
        edge_index.append(ei + shift)
    edge_index = torch.cat(edge_index, dim=1)

    # Return the tensors in the order HEATConv.forward expects:
    # x, edge_index, node_type, edge_type, edge_attr
    return x, edge_index, node_type, edge_type, edge_attr
```

It's not the best solution, but it works for me... hope it helps :)
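One detail worth checking in any such conversion: in a `HeteroData` object, each edge type's `edge_index` numbers nodes locally within its own source and destination node types, so once all node features are stacked into one tensor, the indices have to be shifted by the number of nodes of the preceding types. A minimal pure-PyTorch sketch on a hypothetical two-type graph (the `author`/`paper` names, sizes, and edges are made up for illustration):

```python
import torch

# Hypothetical graph: 2 "author" nodes followed by 3 "paper" nodes once stacked
num_nodes = {"author": 2, "paper": 3}

# Per-type edge indices use local node ids (row 0 = source, row 1 = destination)
edge_index_dict = {
    ("author", "writes", "paper"): torch.tensor([[0, 1], [2, 0]]),
    ("paper", "cites", "paper"): torch.tensor([[0], [1]]),
}

# Start offset of each node type in the stacked node tensor (insertion order)
offsets, start = {}, 0
for name, n in num_nodes.items():
    offsets[name] = start
    start += n

# Shift source ids by the source type's offset and destination ids by the
# destination type's offset, then concatenate all edge types
edge_index = torch.cat(
    [ei + torch.tensor([[offsets[src]], [offsets[dst]]])
     for (src, _, dst), ei in edge_index_dict.items()],
    dim=1,
)

# Integer edge-type label per edge, in the same order as the concatenation
edge_type = torch.cat(
    [torch.full((ei.size(1),), i, dtype=torch.long)
     for i, ei in enumerate(edge_index_dict.values())]
)

print(edge_index.tolist())  # [[0, 1, 2], [4, 2, 3]]
print(edge_type.tolist())   # [0, 0, 1]
```

Without the shift, the `cites` edge `0 -> 1` would point at the two author nodes (global ids 0 and 1) instead of the first two papers (global ids 2 and 3).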
I'm having issues implementing HEATConv, mainly in understanding the input parameters of HEATConv and its forward function. I could not find any sample project or tutorial online.
Below is my implementation, which fails with the following error:
```
/usr/local/lib/python3.7/dist-packages/torch_geometric/nn/dense/linear.py in forward(self, x, type_vec)
    272         else:
    273             assert self.lins is not None
--> 274             out = x.new_empty(x.size(0), self.out_channels)
    275             for i, lin in enumerate(self.lins):
    276                 mask = type_vec == i

AttributeError: 'dict' object has no attribute 'new_empty'
```
Code:

```python
from typing import Dict, Union

import torch.nn as nn
from torch_geometric.nn import HEATConv


class HEAT(nn.Module):
    def __init__(self, in_channels: Union[int, Dict[str, int]], out_channels: int, hidden_channels=128, heads=8, num_node_types=3, num_edge_types=2, edge_type_emb_dim=1, edge_dim=1, edge_attr_emb_dim=1):
        super().__init__()
        self.heat_conv = HEATConv(in_channels, hidden_channels, heads=heads, dropout=0.6, metadata=data.metadata(), num_node_types=num_node_types, num_edge_types=num_edge_types, edge_type_emb_dim=edge_type_emb_dim, edge_dim=edge_dim, edge_attr_emb_dim=edge_attr_emb_dim)
        self.lin = nn.Linear(hidden_channels, out_channels)
```
model.txt