Global graph features #2823
-
Hello, I was wondering how to handle global features (or context), i.e. features that are shared by all the nodes in a graph. To be clear, I am working on a Reinforcement Learning task where my state is represented by a graph. Each node represents an entity with its own features, but I also have "global features" that describe the current state (e.g. the number of steps left, the current score...) and concern every node. I was thinking of simply processing these global features with MLPs and concatenating the result to every node's features, but maybe that is suboptimal. Do you know a smarter way of handling this global input?
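The MLP-then-concatenate idea for a single state graph might look like the minimal sketch below. It assumes plain PyTorch, a node feature matrix `x` of shape `[num_nodes, num_node_features]`, and a 1-D tensor of global values. `GlobalContextEncoder` and the sizes used are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


class GlobalContextEncoder(nn.Module):
    """Hypothetical helper: encodes graph-level features with a small MLP and
    appends the resulting vector to every node's feature vector."""

    def __init__(self, num_global: int, hidden: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(num_global, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
        )

    def forward(self, x: torch.Tensor, global_features: torch.Tensor) -> torch.Tensor:
        # x: [num_nodes, num_node_features]; global_features: [num_global]
        g = self.mlp(global_features)             # [hidden]
        g = g.unsqueeze(0).expand(x.size(0), -1)  # same context vector for every node
        return torch.cat([x, g], dim=-1)          # [num_nodes, num_node_features + hidden]


# Example state: 5 entities with 3 features each, plus 2 global values
# (e.g. steps left and current score).
x = torch.randn(5, 3)
state_globals = torch.tensor([10.0, 42.0])
encoder = GlobalContextEncoder(num_global=2, hidden=4)
out = encoder(x, state_globals)
print(out.shape)  # torch.Size([5, 7])
```

Because the encoder is trained end to end, the GNN layers that follow can learn how much each node should rely on the shared context.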
-
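I think your suggested approach is the most common one. You can easily convert global features to node-wise features during training by indexing them with the `batch` vector:

```python
import torch
data.global_features = ...  # expected shape: [num_graphs, num_global_features]
# data.batch maps each node to its graph index, so each graph's vector is copied to its nodes
data.x = torch.cat([data.x, data.global_features[data.batch]], dim=-1)
```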
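To see what `global_features[batch]` does, here is a small self-contained sketch in plain PyTorch. It stands in for a PyG mini-batch of two graphs and builds the `batch` vector by hand; in practice PyG's `DataLoader` provides it as `data.batch`. It assumes each graph's global features are stored as a `[1, num_global_features]` tensor, which PyG's default batching should concatenate along the first dimension into `[num_graphs, num_global_features]`. The tensor values are made up for illustration.

```python
import torch

# Stand-in for a mini-batch of 2 graphs: 3 nodes in graph 0, 2 nodes in graph 1.
x = torch.arange(10, dtype=torch.float).view(5, 2)   # [num_nodes, 2]
batch = torch.tensor([0, 0, 0, 1, 1])                 # node -> graph index
global_features = torch.tensor([[100.0], [200.0]])    # [num_graphs, 1]

# Advanced indexing gathers row batch[i] of global_features for every node i.
node_globals = global_features[batch]                 # [num_nodes, 1]
x_aug = torch.cat([x, node_globals], dim=-1)          # [num_nodes, 3]

# Equivalent explicit repetition; this relies on nodes being grouped by graph,
# which holds for batches built by concatenating graphs one after another.
counts = torch.bincount(batch)                        # nodes per graph: [3, 2]
assert torch.equal(node_globals, torch.repeat_interleave(global_features, counts, dim=0))

print(x_aug[:, -1])  # tensor([100., 100., 100., 200., 200.])
```

Indexing with `batch` is usually preferred over `repeat_interleave` because it does not depend on node ordering and needs no per-graph counts.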