Skip to content

Commit 91ddd53

Browse files
authored
feat: support weights when generating from SimpleWeightedGraph (#371)
* add simpleweightedgraph support * add test * add SimpleWeightedGraphs to runtests.jl * replace import for using as it's not python * change to extension * remove SimpleWeightedGraphs from deps * add PR review suggestions * add test * refinement * add PR review suggestions
1 parent 19af4ec commit 91ddd53

File tree

5 files changed

+31
-2
lines changed

5 files changed

+31
-2
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2424

2525
[weakdeps]
2626
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
27+
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
2728

2829
[extensions]
2930
GraphNeuralNetworksCUDAExt = "CUDA"
31+
GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs"
3032

3133
[compat]
3234
Adapt = "3, 4"
@@ -45,6 +47,7 @@ NNlib = "0.9"
4547
NearestNeighbors = "0.4"
4648
Random = "1"
4749
Reexport = "1"
50+
SimpleWeightedGraphs = "1.4.0"
4851
SparseArrays = "1"
4952
Statistics = "1"
5053
StatsBase = "0.34"
@@ -59,9 +62,10 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
5962
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
6063
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
6164
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
65+
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
6266
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6367
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6468
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
6569

6670
[targets]
67-
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
71+
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module GraphNeuralNetworksSimpleWeightedGraphsExt
2+
3+
using GraphNeuralNetworks
4+
using Graphs
5+
using SimpleWeightedGraphs
6+
7+
function GraphNeuralNetworks.GNNGraph(g::T; kws...) where
8+
{T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}}
9+
return GNNGraph(g.weights, kws...)
10+
end
11+
12+
end #module

test/GNNGraphs/generate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,4 @@ end
119119
R = 10
120120
tg1 = rand_temporal_hyperbolic_graph(number_nodes, number_snapshots; α, R, speed, ζ)
121121
@test mean(mean(degree.(tg1.snapshots)))<=mean(mean(degree.(tg.snapshots)))
122-
end
122+
end
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@testset "simple_weighted_graph" begin
2+
srcs = [1, 2, 1]
3+
dsts = [2, 3, 3]
4+
wts = [0.5, 0.8, 2.0]
5+
g = SimpleWeightedGraph(srcs, dsts, wts)
6+
gd = SimpleWeightedDiGraph(srcs, dsts, wts)
7+
gnn_g = GNNGraph(g)
8+
gnn_gd = GNNGraph(gd)
9+
@test get_edge_weight(gnn_g) == [0.5, 2, 0.5, 0.8, 2.0, 0.8]
10+
@test get_edge_weight(gnn_gd) == [0.5, 2, 0.8]
11+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using Zygote
1515
using Test
1616
using MLDatasets
1717
using InlineStrings # not used but with the import we test #98 and #104
18+
using SimpleWeightedGraphs
1819

1920
CUDA.allowscalar(false)
2021

@@ -46,6 +47,7 @@ tests = [
4647
"mldatasets",
4748
"examples/node_classification_cora",
4849
"deprecations",
50+
"ext/GraphNeuralNetworksSimpleWeightedGraphsExt/GraphNeuralNetworksSimpleWeightedGraphsExt"
4951
]
5052

5153
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")

0 commit comments

Comments
 (0)