Skip to content

Commit 61b8c12

Browse files
Support for aggregating edge features in remove_multi_edges (#115)
* work on remove multi edges * improve remove_self_loops
1 parent 9a9370c commit 61b8c12

File tree

5 files changed

+77
-36
lines changed

5 files changed

+77
-36
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,7 @@ include("operators.jl")
6767
include("convert.jl")
6868
include("utils.jl")
6969

70-
70+
include("gatherscatter.jl")
71+
# _gather, _scatter
72+
7173
end #module

src/GNNGraphs/gatherscatter.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
_gather(x::NamedTuple, i) = map(x -> _gather(x, i), x)
2+
_gather(x::Tuple, i) = map(x -> _gather(x, i), x)
3+
_gather(x::AbstractArray, i) = NNlib.gather(x, i)
4+
_gather(x::Nothing, i) = nothing
5+
6+
_scatter(aggr, m::NamedTuple, t; dstsize=nothing) = map(m -> _scatter(aggr, m, t; dstsize), m)
7+
_scatter(aggr, m::Tuple, t; dstsize=nothing) = map(m -> _scatter(aggr, m, t; dstsize), m)
8+
_scatter(aggr, m::AbstractArray, t; dstsize=nothing) = NNlib.scatter(aggr, m, t; dstsize)
9+
_scatter(aggr, m::Nothing, t; dstsize=nothing) = nothing

src/GNNGraphs/transform.jl

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,19 @@ end
4141

4242
function remove_self_loops(g::GNNGraph{<:COO_T})
4343
s, t = edge_index(g)
44-
# TODO remove these constraints
45-
@assert g.edata === (;)
46-
@assert get_edge_weight(g) === nothing
44+
w = get_edge_weight(g)
45+
edata = g.edata
4746

4847
mask_old_loops = s .!= t
4948
s = s[mask_old_loops]
5049
t = t[mask_old_loops]
51-
52-
GNNGraph((s, t, nothing),
50+
edata = getobs(edata, mask_old_loops)
51+
w = isnothing(w) ? nothing : getobs(w, mask_old_loops)
52+
53+
GNNGraph((s, t, w),
5354
g.num_nodes, length(s), g.num_graphs,
5455
g.graph_indicator,
55-
g.ndata, g.edata, g.gdata)
56+
g.ndata, edata, g.gdata)
5657
end
5758

5859

@@ -72,25 +73,43 @@ end
7273

7374

7475
"""
75-
remove_multi_edges(g::GNNGraph)
76+
remove_multi_edges(g::GNNGraph; aggr=+)
7677
7778
Remove multiple edges (also called parallel edges or repeated edges) from graph `g`.
79+
Edge features, if present, are aggregated according to `aggr`, which can take the values
80+
`+`, `min`, `max` or `mean`.
81+
82+
See also [`remove_self_loops`](@ref).
7883
"""
79-
function remove_multi_edges(g::GNNGraph{<:COO_T})
84+
function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr=+)
8085
s, t = edge_index(g)
81-
# TODO remove these constraints
82-
@assert g.num_graphs == 1
83-
@assert g.edata === (;)
84-
@assert get_edge_weight(g) === nothing
86+
w = get_edge_weight(g)
87+
edata = g.edata
88+
num_edges = g.num_edges
8589

8690
idxs, idxmax = edge_encoding(s, t, g.num_nodes)
87-
union!(idxs)
88-
s, t = edge_decoding(idxs, g.num_nodes)
91+
92+
perm = sortperm(idxs)
93+
idxs = idxs[perm]
94+
s, t = s[perm], t[perm]
95+
edata = getobs(edata, perm)
96+
w = isnothing(w) ? nothing : getobs(w, perm)
97+
idxs = [-1; idxs]
98+
mask = idxs[2:end] .> idxs[1:end-1]
99+
if !all(mask)
100+
s, t = s[mask], t[mask]
101+
idxs = similar(s, num_edges)
102+
idxs .= 1:num_edges
103+
idxs .= idxs .- cumsum(.!mask)
104+
num_edges = length(s)
105+
w = _scatter(aggr, w, idxs)
106+
edata = _scatter(aggr, edata, idxs)
107+
end
89108

90-
GNNGraph((s, t, nothing),
91-
g.num_nodes, length(s), g.num_graphs,
109+
GNNGraph((s, t, w),
110+
g.num_nodes, num_edges, g.num_graphs,
92111
g.graph_indicator,
93-
g.ndata, g.edata, g.gdata)
112+
g.ndata, edata, g.gdata)
94113
end
95114

96115
"""

src/msgpass.jl

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -112,18 +112,12 @@ apply_edges(l, g::GNNGraph; xi=nothing, xj=nothing, e=nothing) =
112112

113113
function apply_edges(f, g::GNNGraph, xi, xj, e)
114114
s, t = edge_index(g)
115-
xi = _gather(xi, t) # size: (D, num_nodes) -> (D, num_edges)
116-
xj = _gather(xj, s)
115+
xi = GNNGraphs._gather(xi, t) # size: (D, num_nodes) -> (D, num_edges)
116+
xj = GNNGraphs._gather(xj, s)
117117
m = f(xi, xj, e)
118118
return m
119119
end
120120

121-
_gather(x::NamedTuple, i) = map(x -> _gather(x, i), x)
122-
_gather(x::Tuple, i) = map(x -> _gather(x, i), x)
123-
_gather(x::AbstractArray, i) = NNlib.gather(x, i)
124-
_gather(x::Nothing, i) = nothing
125-
126-
127121
## AGGREGATE NEIGHBORS
128122
@doc raw"""
129123
aggregate_neighbors(g::GNNGraph, aggr, m)
@@ -140,14 +134,9 @@ where it comes after [`apply_edges`](@ref).
140134
"""
141135
function aggregate_neighbors(g::GNNGraph, aggr, m)
142136
s, t = edge_index(g)
143-
return _scatter(aggr, m, t)
137+
return GNNGraphs._scatter(aggr, m, t)
144138
end
145139

146-
_scatter(aggr, m::NamedTuple, t) = map(m -> _scatter(aggr, m, t), m)
147-
_scatter(aggr, m::Tuple, t) = map(m -> _scatter(aggr, m, t), m)
148-
_scatter(aggr, m::AbstractArray, t) = NNlib.scatter(aggr, m, t)
149-
150-
151140

152141
### MESSAGE FUNCTIONS ###
153142
"""

test/GNNGraphs/transform.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,25 @@
107107
end
108108

109109
@testset "remove_self_loops" begin
110-
if GRAPH_T == :coo
110+
if GRAPH_T == :coo # add_edges and set_edge_weight only implemented for coo
111111
g = rand_graph(10, 20, graph_type=GRAPH_T)
112112
g1 = add_edges(g, [1:5;], [1:5;])
113113
@test g1.num_edges == g.num_edges + 5
114-
g2 = remove_self_loops(g)
114+
g2 = remove_self_loops(g1)
115115
@test g2.num_edges == g.num_edges
116116
@test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g))
117-
end
117+
118+
# with edge features and weights
119+
g1 = GNNGraph(g1, edata=(e1=ones(3,g1.num_edges), e2=2*ones(g1.num_edges)))
120+
g1 = set_edge_weight(g1, 3*ones(g1.num_edges))
121+
g2 = remove_self_loops(g1)
122+
@test g2.num_edges == g.num_edges
123+
@test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g))
124+
@test size(get_edge_weight(g2)) == (g2.num_edges,)
125+
@test size(g2.edata.e1) == (3, g2.num_edges)
126+
@test size(g2.edata.e2) == (g2.num_edges,)
127+
128+
end
118129
end
119130

120131
@testset "remove_multi_edges" begin
@@ -123,9 +134,20 @@
123134
s, t = edge_index(g)
124135
g1 = add_edges(g, s[1:5], t[1:5])
125136
@test g1.num_edges == g.num_edges + 5
126-
g2 = remove_multi_edges(g)
137+
g2 = remove_multi_edges(g1, aggr=+)
138+
@test g2.num_edges == g.num_edges
139+
@test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g))
140+
141+
# Default aggregation is +
142+
g1 = GNNGraph(g1, edata=(e1=ones(3,g1.num_edges), e2=2*ones(g1.num_edges)))
143+
g1 = set_edge_weight(g1, 3*ones(g1.num_edges))
144+
g2 = remove_multi_edges(g1)
127145
@test g2.num_edges == g.num_edges
128146
@test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g))
147+
@test count(g2.edata.e1[:,i] == 2*ones(3) for i in 1:g2.num_edges) == 5
148+
@test count(g2.edata.e2[i] == 4 for i in 1:g2.num_edges) == 5
149+
w2 = get_edge_weight(g2)
150+
@test count(w2[i] == 6 for i in 1:g2.num_edges) == 5
129151
end
130152
end
131153

0 commit comments

Comments
 (0)