🐛 Describe the bug
`torch_geometric.nn.parameter_dict.ParameterDict` uses a `set` for its `CLASS_ATTRS` attribute. A set is not a valid TorchScript constant, so this prevents scripting of the `HGTConv` layer, which uses a `ParameterDict`.
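TorchScript rejects the container type, not its contents. Below is a minimal pure-Python sketch of one possible fix. It assumes `CLASS_ATTRS` is only used for membership tests, presumably to detect keys that would clash with attributes of `torch.nn.ParameterDict`. Under that assumption, storing the same names as a tuple keeps every `in` check working while using a type TorchScript accepts. `Base`, `CLASS_ATTRS_SET`, `CLASS_ATTRS_TUPLE` and `escape_key` are hypothetical stand-ins, not PyG code.

```python
class Base:
    """Hypothetical stand-in for torch.nn.ParameterDict."""

    def keys(self):
        return []


# Pattern reported in this issue: attribute names stored in a set.
CLASS_ATTRS_SET = set(dir(Base))

# Possible fix: the same names stored as a tuple of str, which matches
# "a list or tuple of" str values in TorchScript's list of valid constants.
CLASS_ATTRS_TUPLE = tuple(sorted(CLASS_ATTRS_SET))


def escape_key(key: str, class_attrs) -> str:
    """Hypothetical helper: wrap keys that would shadow a class attribute."""
    if key in class_attrs:
        return f"<{key}>"
    return key


# Membership semantics are identical for the set and the tuple.
for key in ["keys", "paper", "__init__"]:
    assert escape_key(key, CLASS_ATTRS_SET) == escape_key(key, CLASS_ATTRS_TUPLE)

print(escape_key("keys", CLASS_ATTRS_TUPLE))   # <keys>
print(escape_key("paper", CLASS_ATTRS_TUPLE))  # paper
```

A tuple lookup is a linear scan rather than a hash lookup. If the check only runs when keys are registered, that cost should be negligible.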
Minimal reproducible example:
```python
from torch_geometric.nn import HGTConv
import torch

# metadata = (node_types, edge_types), left empty to keep the example minimal
m = HGTConv(16, 16, ([], []))
s = torch.jit.script(m)  # raises the TypeError shown in the output
```

Output:

```
...
TypeError:
'set' object in attribute 'ParameterDict.CLASS_ATTRS' is not a valid constant.
Valid constants are:
1. a nn.ModuleList
2. a value of type {bool, float, int, str, NoneType, torch.device, torch.layout, torch.dtype, torch.qscheme}
3. a list or tuple of (2)
```
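The rules in this error message can be illustrated with a small pure-Python approximation. `is_valid_constant` is a hypothetical helper, not a PyTorch API. It encodes only rules (2) and (3) as printed and drops the `torch.*` types from rule (2) so it runs without PyTorch. It also ignores rule (1). The real check inside `torch.jit` may differ in details such as nested containers.

```python
from typing import Any

# Rule (2) from the error message, restricted to plain Python types.
# torch.device, torch.layout, torch.dtype and torch.qscheme are omitted
# so that this sketch runs without PyTorch installed.
SCALAR_CONSTANT_TYPES = (bool, float, int, str, type(None))


def is_valid_constant(value: Any) -> bool:
    """Approximate rules (2) and (3) of the TorchScript constant check."""
    if isinstance(value, SCALAR_CONSTANT_TYPES):
        return True  # rule (2): a scalar of an accepted type
    if isinstance(value, (list, tuple)):
        # rule (3): a list or tuple whose elements all satisfy rule (2)
        return all(isinstance(v, SCALAR_CONSTANT_TYPES) for v in value)
    return False  # anything else, including set and frozenset


attrs = {"keys", "values", "items"}
print(is_valid_constant(attrs))                 # False: a set is rejected
print(is_valid_constant(tuple(sorted(attrs))))  # True: a tuple of str is accepted
print(is_valid_constant(frozenset(attrs)))      # False: frozenset is rejected too
```

Under these rules, making the set immutable with `frozenset` would not be enough; the names have to be stored in a list or tuple.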
This appears to conflict with the following statement in the documentation:

> From PyG 2.5 (and onwards), GNN layers are now fully compatible with torch.jit.script() without any modification needed

(https://pytorch-geometric.readthedocs.io/en/latest/notes/jit.html)
Versions
Versions of relevant libraries:
```
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] nvtx==0.2.13
[pip3] pytorch-lightning==2.5.5
[pip3] torch==2.8.0+cu126
[pip3] torch-geometric==2.7.0
[pip3] torch_scatter==2.1.2+pt28cu126
[pip3] torchmetrics==1.8.2
[pip3] torchtext==0.18.0
[pip3] triton==3.4.0
[conda] numpy 1.24.2 pypi_0 pypi
```