🐛 Describe the bug
The default `post_step` in `LabelPropagation` clamps all values to [0, 1]. As far as I can tell, this step isn't discussed in the reference from which the implementation is derived (Qian Huang et al., 2020), which in turn derives the label propagation model from Dengyong Zhou et al., 'Learning with Local and Global Consistency', Advances in Neural Information Processing Systems 16 (2003). Where clamping is discussed is in the original reference (Xiaojin Zhu and Zoubin Ghahramani, 2002), and there the motivation appears to be preserving the values of the labelled data throughout the propagation process.
The current default `post_step` doesn't achieve this preservation of labelled node values: it only ensures the values at the training nodes lie within [0, 1]. Moreover, `post_step` only receives `out` as a parameter, so it is not possible to access the train mask and reset the training nodes to their original values.
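Since `post_step` receives only `out`, one possible workaround is to capture the labels and the train mask in a closure and reset the labelled rows after every step, in the spirit of Zhu & Ghahramani (2002). This is only a sketch (the `make_post_step` helper and variable names are hypothetical, and it assumes the installed PyG version lets a custom `post_step` callable be supplied):

```python
import torch

# Hypothetical workaround: capture the train mask and original one-hot
# labels in a closure, since post_step itself only receives `out`.
y_onehot = torch.tensor([[0., 0.],                      # node 0: unlabelled
                         [1., 0.], [1., 0.], [1., 0.],  # nodes 1-6: class 0
                         [1., 0.], [1., 0.], [1., 0.],
                         [0., 1.], [0., 1.]])           # nodes 7-8: class 1
train_mask = torch.tensor([False] + [True] * 8)

def make_post_step(y_onehot, train_mask):
    def post_step(out):
        # Reset labelled nodes to their original labels after each
        # propagation step (as in Zhu & Ghahramani, 2002), rather than
        # clamping every entry to [0, 1].
        out[train_mask] = y_onehot[train_mask]
        return out
    return post_step

post_step = make_post_step(y_onehot, train_mask)

# Simulate a propagation step that has shrunk the labelled rows:
out = torch.full((9, 2), 0.1)
out = post_step(out)
print(out[1])  # labelled row restored to its one-hot label
```

If the installed version exposes `post_step` as a forward argument, this could then be passed as `model(y=y, edge_index=edge_index, mask=train_mask, post_step=post_step)`.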
Finally, this clamping also has the adverse effect of, in some cases, making the evidence for different classes at unlabelled nodes more similar than it should be.
This is most evident when a lot of information is aggregated at a single node, as in this example:
```python
import torch
from torch_geometric.nn.models import LabelPropagation

y = torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 1])
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0]])
train_mask = torch.tensor([False, True, True, True, True, True, True, True, True])

model = LabelPropagation(num_layers=2, alpha=0.9)
out = model(y=y, edge_index=edge_index, mask=train_mask,
            edge_weight=torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1]))
print(out)
```

In this example, after the first forward pass the value is:
```
tensor([[1.0000, 1.0000],
        [0.1000, 0.0000],
        [0.1000, 0.0000],
        [0.1000, 0.0000],
        [0.1000, 0.0000],
        [0.1000, 0.0000],
        [0.1000, 0.0000],
        [0.0000, 0.1000],
        [0.0000, 0.1000]])
```

i.e. the unknown node ends up with equal evidence for classes 0 and 1, which is counterintuitive since more of its edges come from nodes of class 0. This happens because it has aggregated the information from all the other nodes, initially producing values above 1 for both classes (with a larger value for class 0), which are then both clamped to 1. The values at the training/labelled nodes have also been dramatically decreased, so in any subsequent forward passes the information from these nodes is much weaker.
It seems to me that the value of `alpha`, from (Zhou, 2003), already helps to maintain the influence of the labelled/training nodes, so the clamping isn't needed for that purpose. That said, if the post-step does nothing at all, the values explode as they are repeatedly aggregated.
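One alternative that would keep the values bounded without distorting relative evidence is to renormalize each row instead of clamping it, so that a node which aggregates more evidence for class 0 than class 1 keeps that ratio. This is only a sketch of a possible `post_step` replacement, not the library's current behaviour, and the numbers used are hypothetical:

```python
import torch

def normalize_post_step(out):
    # Rescale each row to sum to 1 instead of clamping entries to
    # [0, 1]; clamping can equalize classes once both exceed 1,
    # whereas renormalizing preserves their relative evidence.
    row_sum = out.sum(dim=-1, keepdim=True).clamp(min=1e-12)
    return out / row_sum

# Hypothetical pre-clamp values at a hub node like the one in the
# example above: more evidence for class 0 than class 1.
aggregated = torch.tensor([[5.4, 1.8]])
print(normalize_post_step(aggregated))  # class 0 keeps 3x the evidence
```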
Versions
Environment
PyG version: 2.7.0
PyTorch version: 2.6.0
OS: linux
Python version: 3.12
CUDA/cuDNN version: 11.8
How you installed PyTorch and PyG (conda, pip, source): conda
Any other relevant information (e.g., version of torch-scatter):