Skip to content

Commit 18d327d

Browse files
add to_bidirected (#116)
* add to_bidirected(g) * docs
1 parent 61b8c12 commit 18d327d

File tree

3 files changed

+101
-6
lines changed

3 files changed

+101
-6
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import LearnBase
1313
import StatsBase
1414
import KrylovKit
1515
using ChainRulesCore
16-
using LinearAlgebra, Random
16+
using LinearAlgebra, Random, Statistics
1717

1818
include("gnngraph.jl")
1919
export GNNGraph,
@@ -48,6 +48,7 @@ export add_nodes,
4848
remove_self_loops,
4949
remove_multi_edges,
5050
set_edge_weight,
51+
to_bidirected,
5152
# from Flux
5253
batch,
5354
unbatch,

src/GNNGraphs/transform.jl

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Remove multiple edges (also called parallel edges or repeated edges) from graph
7979
Possible edge features are aggregated according to `aggr`, that can take value
8080
`+`,`min`, `max` or `mean`.
8181
82-
See also [`remove_self_loops`](@ref).
82+
See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref).
8383
"""
8484
function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr=+)
8585
s, t = edge_index(g)
@@ -119,14 +119,14 @@ Add to graph `g` the edges with source nodes `s` and target nodes `t`.
119119
Optionally, pass the features `edata` for the new edges.
120120
"""
121121
function add_edges(g::GNNGraph{<:COO_T},
122-
snew::AbstractVector{<:Integer},
123-
tnew::AbstractVector{<:Integer};
124-
edata=nothing)
122+
snew::AbstractVector{<:Integer},
123+
tnew::AbstractVector{<:Integer};
124+
edata=nothing)
125125

126126
@assert length(snew) == length(tnew)
127127
# TODO remove this constraint
128128
@assert get_edge_weight(g) === nothing
129-
129+
130130
edata = normalize_graphdata(edata, default_name=:e, n=length(snew))
131131
edata = cat_features(g.edata, edata)
132132

@@ -157,6 +157,77 @@ end
157157
# return true
158158
# end
159159

160+
"""
161+
to_bidirected(g)
162+
163+
Adds a reverse edge for each edge in the graph, then calls
164+
[`remove_multi_edges`](@ref) with `mean` aggregation to simplify the graph.
165+
166+
See also [`is_bidirected`](@ref).
167+
168+
# Examples
169+
170+
```juliarepl
171+
julia> s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4];
172+
173+
julia> w = [1.0, 2.0, 3.0, 4.0, 5.0];
174+
175+
julia> e = [10.0, 20.0, 30.0, 40.0, 50.0];
176+
177+
julia> g = GNNGraph(s, t, w, edata = e)
178+
GNNGraph:
179+
num_nodes = 4
180+
num_edges = 5
181+
edata:
182+
e => (5,)
183+
184+
julia> g2 = to_bidirected(g)
185+
GNNGraph:
186+
num_nodes = 4
187+
num_edges = 7
188+
edata:
189+
e => (7,)
190+
191+
julia> edge_index(g2)
192+
([1, 2, 2, 3, 3, 4, 4], [2, 1, 3, 2, 4, 3, 4])
193+
194+
julia> get_edge_weight(g2)
195+
7-element Vector{Float64}:
196+
1.0
197+
1.0
198+
2.0
199+
2.0
200+
3.5
201+
3.5
202+
5.0
203+
204+
julia> g2.edata.e
205+
7-element Vector{Float64}:
206+
10.0
207+
10.0
208+
20.0
209+
20.0
210+
35.0
211+
35.0
212+
50.0
213+
```
214+
"""
215+
function to_bidirected(g::GNNGraph{<:COO_T})
216+
s, t = edge_index(g)
217+
w = get_edge_weight(g)
218+
snew = [s; t]
219+
tnew = [t; s]
220+
w = cat_features(w, w)
221+
edata = cat_features(g.edata, g.edata)
222+
223+
g = GNNGraph((snew, tnew, w),
224+
g.num_nodes, length(snew), g.num_graphs,
225+
g.graph_indicator,
226+
g.ndata, edata, g.gdata)
227+
228+
return remove_multi_edges(g; aggr=mean)
229+
end
230+
160231

161232
"""
162233
add_nodes(g::GNNGraph, n; [ndata])

test/GNNGraphs/transform.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,4 +209,27 @@
209209
gw2 = set_edge_weight(g2, w)
210210
@test get_edge_weight(gw2) == w
211211
end
212+
213+
@testset "to_bidirected" begin
214+
if GRAPH_T == :coo
215+
s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4]
216+
w = [1.0, 2.0, 3.0, 4.0, 5.0]
217+
e = [10.0, 20.0, 30.0, 40.0, 50.0]
218+
g = GNNGraph(s, t, w, edata = e)
219+
220+
g2 = to_bidirected(g)
221+
@test g2.num_nodes == g.num_nodes
222+
@test g2.num_edges == 7
223+
s2, t2 = edge_index(g2)
224+
@test is_bidirected(g2)
225+
@test !has_multi_edges(g2)
226+
227+
s2, t2 = edge_index(g2)
228+
w2 = get_edge_weight(g2)
229+
@test s2 == [1, 2, 2, 3, 3, 4, 4]
230+
@test t2 == [2, 1, 3, 2, 4, 3, 4]
231+
@test w2 == [1, 1, 2, 2, 3.5, 3.5, 5]
232+
@test g2.edata.e == [10.0, 10.0, 20.0, 20.0, 35.0, 35.0, 50.0]
233+
end
234+
end
212235
end

0 commit comments

Comments
 (0)