Skip to content

Commit 194e8f8

Browse files
implement unidirected
1 parent f8c7712 commit 194e8f8

File tree

5 files changed

+102
-3
lines changed

5 files changed

+102
-3
lines changed

docs/Project.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
4+
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
45
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
56
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
69
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
10+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
11+
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
12+
PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad"
13+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
714
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
15+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
16+
WGLMakie = "276b4fcb-3e11-5398-bf8b-a0c2d153d008"

docs/make.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,52 @@
11
using Flux, NNlib, GraphNeuralNetworks, Graphs, SparseArrays
22
using Documenter
3+
using Pluto, PlutoStaticHTML
4+
5+
TutorialMenu = Array{Pair{String,String},1}()
6+
7+
#
8+
# Generate Pluto Tutorial HTMLs
9+
10+
# First tutorial with AD
11+
pluto_src_folder = joinpath(@__DIR__, "tutorials/")
12+
pluto_output_folder = joinpath(@__DIR__, "tutorials/")
13+
pluto_relative_path = "tutorials/"
14+
mkpath(pluto_output_folder)
15+
#
16+
#
17+
# Please do not use the same name as for a(n old) literate Tutorial
18+
pluto_files = [
19+
"gnn_intro.pluto",
20+
"graph_classification.pluto",
21+
]
22+
pluto_titles = [
23+
"Intro to Graph Neural Networks ",
24+
"Graph Classification",
25+
]
26+
27+
# build menu and write files myself - tp set edit url correctly.
28+
for (title, file) in zip(pluto_titles, pluto_files)
29+
global TutorialMenu
30+
rendered = build_notebooks( #though not really parallel here
31+
BuildOptions(
32+
pluto_src_folder;
33+
output_format=documenter_output,
34+
write_files=false,
35+
use_distributed=false,
36+
),
37+
["$(file).jl"],
38+
)
39+
write(
40+
pluto_output_folder * file * ".md",
41+
"""
42+
```@meta
43+
EditURL = "$(pluto_src_folder)$(file).jl"
44+
```
45+
$(rendered[1])
46+
""",
47+
)
48+
push!(TutorialMenu, title => joinpath(pluto_relative_path, file * ".md"))
49+
end
350

451
DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup,
552
:(using GraphNeuralNetworks, Graphs, SparseArrays, NNlib, Flux);

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ export add_nodes,
5050
remove_multi_edges,
5151
set_edge_weight,
5252
to_bidirected,
53+
to_unidirected,
5354
# from Flux
5455
batch,
5556
unbatch,

src/GNNGraphs/transform.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,27 @@ function to_bidirected(g::GNNGraph{<:COO_T})
235235
return remove_multi_edges(g; aggr=mean)
236236
end
237237

238+
"""
239+
to_unidirected(g::GNNGraph)
240+
241+
Return a graph that for each multiple edge between two nodes in `g`
242+
keeps only an edge in one direction.
243+
"""
244+
function to_unidirected(g::GNNGraph{<:COO_T})
245+
s, t = edge_index(g)
246+
w = get_edge_weight(g)
247+
idxs, _ = edge_encoding(s, t, g.num_nodes, directed=false)
248+
snew, tnew = edge_decoding(idxs, g.num_nodes, directed=false)
249+
250+
g = GNNGraph((snew, tnew, w),
251+
g.num_nodes, g.num_edges, g.num_graphs,
252+
g.graph_indicator,
253+
g.ndata, g.edata, g.gdata)
254+
255+
return remove_multi_edges(g; aggr=mean)
256+
end
257+
258+
238259

239260
"""
240261
add_nodes(g::GNNGraph, n; [ndata])

test/GNNGraphs/transform.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,16 +225,37 @@
225225
g2 = to_bidirected(g)
226226
@test g2.num_nodes == g.num_nodes
227227
@test g2.num_edges == 7
228-
s2, t2 = edge_index(g2)
229228
@test is_bidirected(g2)
230229
@test !has_multi_edges(g2)
231-
232-
s2, t2 = edge_index(g2)
230+
231+
s2, t2 = edge_index(g2)
233232
w2 = get_edge_weight(g2)
234233
@test s2 == [1, 2, 2, 3, 3, 4, 4]
235234
@test t2 == [2, 1, 3, 2, 4, 3, 4]
236235
@test w2 == [1, 1, 2, 2, 3.5, 3.5, 5]
237236
@test g2.edata.e == [10.0, 10.0, 20.0, 20.0, 35.0, 35.0, 50.0]
238237
end
239238
end
239+
240+
@testset "to_unidirected" begin
241+
if GRAPH_T == :coo
242+
s = [1, 2, 3, 4, 4],
243+
t = [2, 3, 4, 3, 4]
244+
w = [1.0, 2.0, 3.0, 4.0, 5.0]
245+
e = [10.0, 20.0, 30.0, 40.0, 50.0]
246+
g = GNNGraph(s, t, w, edata = e)
247+
248+
g2 = to_unidirected(g)
249+
@test g2.num_nodes == g.num_nodes
250+
@test g2.num_edges == 4
251+
@test !has_multi_edges(g2)
252+
253+
s2, t2 = edge_index(g2)
254+
w2 = get_edge_weight(g2)
255+
@test s2 == [1, 2, 3, 4]
256+
@test t2 == [2, 3, 4, 4]
257+
@test w2 == [1, 2, 3.5, 5]
258+
@test g2.edata.e == [10.0, 20.0, 35.0, 50.0]
259+
end
260+
end
240261
end

0 commit comments

Comments
 (0)