Skip to content

Commit 9157aa9

Browse files
fix
1 parent 2319a89 commit 9157aa9

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

src/GNNGraphs/gatherscatter.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,20 @@ _gather(x::Tuple, i) = map(x -> _gather(x, i), x)
33
_gather(x::AbstractArray, i) = NNlib.gather(x, i)
44
_gather(x::Nothing, i) = nothing
55

6-
_scatter(aggr, m::Nothing, t; dstsize=nothing) = nothing
7-
_scatter(aggr, m::NamedTuple, t; dstsize=nothing) = map(m -> _scatter(aggr, m, t; dstsize), m)
8-
_scatter(aggr, m::Tuple, t; dstsize=nothing) = map(m -> _scatter(aggr, m, t; dstsize), m)
9-
function _scatter(aggr, m::AbstractMatrix, t; dstsize=nothing)
10-
if dstsize !== nothing
11-
dstsize = (size(m,1), dstsize)
12-
end
13-
NNlib.scatter(aggr, m, t; dstsize)
6+
_scatter(aggr, src::Nothing, idx, n) = nothing
7+
_scatter(aggr, src::NamedTuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
8+
_scatter(aggr, src::Tuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
9+
10+
function _scatter(aggr,
11+
src::AbstractArray,
12+
idx::AbstractVector{<:Integer},
13+
n::Integer)
14+
15+
dstsize = (size(src)[1:end-1]..., n)
16+
NNlib.scatter(aggr, src, idx; dstsize)
1417
end
18+
19+
1520
## TO MOVE TO NNlib ######################################################
1621

1722

src/GNNGraphs/transform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr=+)
109109
idxs .= 1:num_edges
110110
idxs .= idxs .- cumsum(.!mask)
111111
num_edges = length(s)
112-
w = _scatter(aggr, w, idxs)
113-
edata = _scatter(aggr, edata, idxs)
112+
w = _scatter(aggr, w, idxs, num_edges)
113+
edata = _scatter(aggr, edata, idxs, num_edges)
114114
end
115115

116116
GNNGraph((s, t, w),

src/msgpass.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ where it comes after [`apply_edges`](@ref).
134134
"""
135135
function aggregate_neighbors(g::GNNGraph, aggr, m)
136136
s, t = edge_index(g)
137-
return GNNGraphs._scatter(aggr, m, t; dstsize=g.num_nodes)
137+
return GNNGraphs._scatter(aggr, m, t, g.num_nodes)
138138
end
139139

140140

0 commit comments

Comments
 (0)