post_step clamping process doesn't preserve training labels in LabelPropagation #10627

@alrichardbollans

Description

🐛 Describe the bug

The default post_step method for LabelPropagation clamps all values to [0, 1]. As far as I can tell, this clamping isn't discussed in the reference from which the implementation is derived (Qian Huang et al., 2020), which in turn derives the label propagation model from Dengyong Zhou et al., 'Learning with Local and Global Consistency', Advances in Neural Information Processing Systems 16 (2003). The clamping is discussed in the original reference (Xiaojin Zhu and Zoubin Ghahramani, 2002), where the motivation appears to be to preserve the values of the labelled data throughout the training process.

The current implementation of the default post_step doesn't achieve this preservation of the labelled node values: it only ensures that the values at the training nodes lie within [0, 1]. Moreover, the post_step method only takes out as a parameter, so it is not possible to access the train mask and reset the training nodes to their original values.

Finally, this clamping also has the adverse effect, in some cases, of making the evidence for different classes at unlabelled nodes more similar than it should be.

This is all most evident in cases where a lot of information is aggregated at a single node, as in the following example:

import torch
from torch_geometric.nn.models import LabelPropagation

# Node 0 is unlabelled; all other nodes are labelled and point to node 0.
y = torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 1])
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0]])
train_mask = torch.tensor([False, True, True, True, True, True, True, True, True])

model = LabelPropagation(num_layers=2, alpha=0.9)
out = model(y=y, edge_index=edge_index, mask=train_mask,
            edge_weight=torch.ones(9))  # float weights; all edges weighted equally
print(out)

In this example, after the first propagation step the output is:

tensor([[1.0000, 1.0000],
        [0.1000, 0.0000],
        [0.1000, 0.0000],
        [0.1000, 0.0000],
        [0.1000, 0.0000],
        [0.1000, 0.0000],
        [0.1000, 0.0000],
        [0.0000, 0.1000],
        [0.0000, 0.1000]])

i.e. the unlabelled node ends up with equal evidence for classes 0 and 1, which is counterintuitive since more of its edges come from class-0 nodes. The node aggregates the information from all the other nodes, which initially gives values above 1 for both classes, with a larger value for class 0, and both are then clamped to 1. The values at the training/labelled nodes are also dramatically decreased, so in any subsequent steps the information from these nodes is much weaker.
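Node 0's pre-clamp value can be reproduced by hand under the Zhou-style update out = alpha * (aggregated neighbours) + (1 - alpha) * y_onehot, which is my reading of the implementation rather than a verbatim copy of the PyG source:

```python
import torch

alpha = 0.9
y_onehot = torch.tensor([[0., 0.],                      # node 0: unlabelled, zeroed by the mask
                         [1., 0.], [1., 0.], [1., 0.],
                         [1., 0.], [1., 0.], [1., 0.],  # six class-0 nodes
                         [0., 1.], [0., 1.]])           # two class-1 nodes

agg = y_onehot[1:].sum(dim=0)  # node 0 aggregates all 8 labelled neighbours -> [6., 2.]
pre_clamp = alpha * agg + (1 - alpha) * y_onehot[0]
print(pre_clamp)               # roughly [5.4, 1.8]: class 0 clearly dominates
print(pre_clamp.clamp(0., 1.)) # [1., 1.]: clamping erases the difference
```

So the clamp step is exactly where the [5.4, 1.8] evidence collapses to [1.0, 1.0].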

It seems to me that the value of alpha, from (Zhou et al., 2003), already helps to maintain the influence of the labelled/training nodes, so this clamping process isn't needed. That said, with no post_step at all, the values explode as they are aggregated.
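One alternative to clamping that keeps the values bounded without erasing the relative evidence would be to L1-normalise each row; this is my own suggestion, not something PyG currently does:

```python
import torch

def normalise_rows(out: torch.Tensor) -> torch.Tensor:
    # Row-wise L1 normalisation: bounds each row to sum to 1 while
    # preserving the ratio between class scores (unlike clamping).
    norm = out.sum(dim=-1, keepdim=True).clamp(min=1e-12)
    return out / norm

pre = torch.tensor([[5.4, 1.8]])  # node 0's pre-clamp evidence from the example above
print(normalise_rows(pre))        # class 0 keeps its 3:1 advantage
```

With this post_step the node from the example keeps roughly [0.75, 0.25] instead of the uninformative [1.0, 1.0].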

Versions

PyG version: 2.7.0
PyTorch version: 2.6.0
OS: linux
Python version: 3.12
CUDA/cuDNN version: 11.8
How you installed PyTorch and PyG (conda, pip, source): conda
Any other relevant information (e.g., version of torch-scatter):
