Skip to content

Commit fbafbc4

Browse files
authored
Allow optional but untyped tensors in MessagePassing (#9494)
Fixes #9492
1 parent c8cd4de commit fbafbc4

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3636

3737
### Changed
3838

39+
- Allow optional but untyped tensors in `MessagePassing` ([#9494](https://github.com/pyg-team/pytorch_geometric/pull/9494))
3940
- Added support for modifying `filename` of the stored partitioned file in `ClusterLoader` ([#9448](https://github.com/pyg-team/pytorch_geometric/pull/9448))
4041
- Support other than two-dimensional inputs in `AttentionalAggregation` ([#9433](https://github.com/pyg-team/pytorch_geometric/pull/9433))
4142
- Improved model performance of the `examples/ogbn_papers_100m.py` script ([#9386](https://github.com/pyg-team/pytorch_geometric/pull/9386), [#9445](https://github.com/pyg-team/pytorch_geometric/pull/9445))

test/nn/conv/test_message_passing.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,3 +740,24 @@ def test_pickle(tmp_path):
740740

741741
model = torch.load(path)
742742
torch.jit.script(model)
743+
744+
745+
class MyOptionalEdgeAttrConv(MessagePassing):
746+
def __init__(self):
747+
super().__init__()
748+
749+
def forward(self, x, edge_index, edge_attr=None):
750+
return self.propagate(edge_index, x=x, edge_attr=edge_attr)
751+
752+
def message(self, x_j, edge_attr=None):
753+
return x_j if edge_attr is None else x_j * edge_attr.view(-1, 1)
754+
755+
756+
def test_my_optional_edge_attr_conv():
757+
conv = MyOptionalEdgeAttrConv()
758+
759+
x = torch.randn(4, 8)
760+
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
761+
762+
out = conv(x, edge_index)
763+
assert out.size() == (4, 8)

torch_geometric/nn/conv/collect.jinja

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,16 @@ def {{collect_name}}(
9898

9999
{%- if 'edge_weight' in collect_param_dict and
100100
collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %}
101-
assert edge_weight is not None
101+
if torch.jit.is_scripting():
102+
assert edge_weight is not None
102103
{%- elif 'edge_attr' in collect_param_dict and
103104
collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %}
104-
assert edge_attr is not None
105+
if torch.jit.is_scripting():
106+
assert edge_attr is not None
105107
{%- elif 'edge_type' in collect_param_dict and
106108
collect_param_dict['edge_type'].type_repr.endswith('Tensor') %}
107-
assert edge_type is not None
109+
if torch.jit.is_scripting():
110+
assert edge_type is not None
108111
{%- endif %}
109112

110113
# Collect user-defined arguments:

0 commit comments

Comments
 (0)