diff --git a/test/utils/test_softmax.py b/test/utils/test_softmax.py index 7102d3203f07..a7033e78c9f5 100644 --- a/test/utils/test_softmax.py +++ b/test/utils/test_softmax.py @@ -17,11 +17,13 @@ def test_softmax(): out = softmax(src, index) assert out.tolist() == [0.5, 0.5, 1, 1] assert softmax(src, ptr=ptr).tolist() == out.tolist() + assert softmax(src, index=index, ptr=ptr).tolist() == out.tolist() src = src.view(-1, 1) out = softmax(src, index) assert out.tolist() == [[0.5], [0.5], [1], [1]] assert softmax(src, ptr=ptr).tolist() == out.tolist() + assert softmax(src, index=index, ptr=ptr).tolist() == out.tolist() jit = torch.jit.script(softmax) assert torch.allclose(jit(src, index), out) diff --git a/torch_geometric/utils/_softmax.py b/torch_geometric/utils/_softmax.py index c6f19f8e4a0d..4de7c8432a5a 100644 --- a/torch_geometric/utils/_softmax.py +++ b/torch_geometric/utils/_softmax.py @@ -65,11 +65,14 @@ def softmax( size = ([1] * dim) + [-1] count = ptr[1:] - ptr[:-1] ptr = ptr.view(size) + output_size = index.shape[dim] if index is not None else None src_max = segment(src.detach(), ptr, reduce='max') - src_max = src_max.repeat_interleave(count, dim=dim) + src_max = src_max.repeat_interleave(count, dim=dim, + output_size=output_size) out = (src - src_max).exp() out_sum = segment(out, ptr, reduce='sum') + 1e-16 - out_sum = out_sum.repeat_interleave(count, dim=dim) + out_sum = out_sum.repeat_interleave(count, dim=dim, + output_size=output_size) elif index is not None: N = maybe_num_nodes(index, num_nodes) src_max = scatter(src.detach(), index, dim, dim_size=N, reduce='max')