How to use HeteroBatchNorm and HeteroLayerNorm? #9182
I noticed that HeteroBatchNorm requires x and type_vec as inputs, and HeteroLayerNorm requires x, type_vec and type_ptr as inputs.
type_vec and type_ptr describe the same information, just in a different format. type_vec is the classic index representation, e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2, ...], which assigns each node to a specific node type. type_ptr is its compressed representation, i.e., it only stores the boundaries between node types. For example, [0, 4, 7] describes two node types where the first node type has 4 nodes and the second has 3 nodes.

If you are using dictionaries to process heterogeneous graphs, you can just do:

```python
import torch
from torch_geometric.nn import BatchNorm

# One normalization layer per node type (fill in BatchNorm's arguments, e.g. the feature size).
norms = torch.nn.ModuleDict({node_type: BatchNorm(...) for node_type in node_types})

# dict.items() yields (key, value) pairs, so unpack them in that order.
x_dict = {key: norms[key](x) for key, x in x_dict.items()}
```