Skip to content

Commit 1aade04

Browse files
export get_edge_weight + use weights from adjacency matrix (#86)
* When constructing a graph from an adjacency matrix, the matrix's elements are now considered edge weights. * export get_edge_weight
1 parent 8530067 commit 1aade04

File tree

10 files changed

+82
-30
lines changed

10 files changed

+82
-30
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ Graphs = "1.4"
3535
KrylovKit = "0.5"
3636
LearnBase = "0.4, 0.5"
3737
MacroTools = "0.5"
38-
NearestNeighbors = "0.4"
3938
NNlib = "0.7"
4039
NNlibCUDA = "0.1"
40+
NearestNeighbors = "0.4"
4141
Reexport = "1"
4242
StatsBase = "0.32, 0.33"
4343
julia = "1.6"

docs/src/gnngraph.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,15 @@ julia> g = GNNGraph(source, target, weight)
130130
GNNGraph:
131131
num_nodes = 3
132132
num_edges = 6
133-
133+
134+
julia> get_edge_weight(g)
135+
6-element Vector{Float64}:
136+
1.0
137+
0.5
138+
2.1
139+
2.3
140+
4.0
141+
4.1
134142
```
135143

136144
## Batches and Subgraphs

src/GNNGraphs/GNNGraphs.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ export GNNGraph,
2424

2525
include("query.jl")
2626
export adjacency_list,
27-
edge_index,
27+
edge_index,
28+
get_edge_weight,
2829
graph_indicator,
29-
has_multi_edges,
30+
has_multi_edges,
3031
is_directed,
3132
is_bidirected,
3233
normalized_laplacian,
@@ -39,16 +40,16 @@ export adjacency_list,
3940
outneighbors
4041

4142
include("transform.jl")
42-
export add_nodes,
43-
add_edges,
43+
export add_nodes,
44+
add_edges,
4445
add_self_loops,
4546
getgraph,
4647
negative_sample,
4748
rand_edge_split,
48-
remove_self_loops,
49+
remove_self_loops,
4950
remove_multi_edges,
5051
# from Flux
51-
batch,
52+
batch,
5253
unbatch,
5354
# from SparseArrays
5455
blockdiag

src/GNNGraphs/convert.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,22 @@ function to_coo(A::SPARSE_T; dir=:out, num_nodes=nothing)
1818
if dir == :in
1919
s, t = t, s
2020
end
21-
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
21+
num_nodes = isnothing(num_nodes) ? size(A, 1) : num_nodes
2222
num_edges = length(s)
2323

24-
return (s, t, nothing), num_nodes, num_edges
24+
return (s, t, v), num_nodes, num_edges
2525
end
2626

2727
function to_coo(A::ADJMAT_T; dir=:out, num_nodes=nothing)
2828
nz = findall(!=(0), A) # vec of cartesian indexes
2929
s, t = ntuple(i -> map(t->t[i], nz), 2)
30+
v = A[nz]
3031
if dir == :in
3132
s, t = t, s
3233
end
33-
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
34+
num_nodes = isnothing(num_nodes) ? size(A, 1) : num_nodes
3435
num_edges = length(s)
35-
return (s, t, nothing), num_nodes, num_edges
36+
return (s, t, v), num_nodes, num_edges
3637
end
3738

3839
function to_coo(adj_list::ADJLIST_T; dir=:out, num_nodes=nothing)
@@ -140,15 +141,17 @@ end
140141

141142
function to_sparse(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing)
142143
s, t, eweight = coo
143-
T = T === nothing ? eltype(s) : T
144-
eweight = isnothing(eweight) ? fill!(similar(s, T), 1) : eweight
144+
T = T === nothing ? (eweight === nothing ? eltype(s) : eltype(eweight)) : T
145+
eweight = eweight === nothing ? fill!(similar(s, T), 1) : eweight
145146
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
146147
A = sparse(s, t, eweight, num_nodes, num_nodes)
147-
num_edges = length(s)
148+
num_edges = nnz(A)
149+
if eltype(A) != T
150+
A = T.(A)
151+
end
148152
return A, num_nodes, num_edges
149153
end
150154

151-
152155
@non_differentiable to_coo(x...)
153156
@non_differentiable to_dense(x...)
154157
@non_differentiable to_sparse(x...)

src/GNNGraphs/gnngraph.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ functionality from that library.
4747
- An adjacency matrix
4848
- An adjacency list.
4949
- A tuple containing the source and target vectors (COO representation)
50-
- A Graphs' graph.
50+
- A Graphs.jl' graph.
5151
- `graph_type`: A keyword argument that specifies
5252
the underlying representation used by the GNNGraph.
5353
Currently supported values are
@@ -61,9 +61,9 @@ functionality from that library.
6161
Possible values are `:out` and `:in`. Default `:out`.
6262
- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.
6363
- `graph_indicator`: For batched graphs, a vector containing the graph assigment of each node. Default `nothing`.
64-
- `ndata`: Node features. A named tuple of arrays whose last dimension has size `num_nodes`.
65-
- `edata`: Edge features. A named tuple of arrays whose last dimension has size `num_edges`.
66-
- `gdata`: Graph features. A named tuple of arrays whose last dimension has size `num_graphs`.
64+
- `ndata`: Node features. An array or named tuple of arrays whose last dimension has size `num_nodes`.
65+
- `edata`: Edge features. An array or named tuple of arrays whose last dimension has size `num_edges`.
66+
- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`.
6767
6868
# Examples
6969

src/GNNGraphs/query.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normal
230230
"""
231231
function scaled_laplacian(g::GNNGraph, T::DataType=Float32; dir=:out)
232232
L = normalized_laplacian(g, T)
233-
@assert issymmetric(L) "scaled_laplacian only works with symmetric matrices"
233+
# @assert issymmetric(L) "scaled_laplacian only works with symmetric matrices"
234234
λmax = _eigmax(L)
235235
return 2 / λmax * L - I
236236
end

src/GraphNeuralNetworks.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,21 @@ using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
2222
export
2323
# utils
2424
reduce_nodes,
25-
reduce_edges,
25+
reduce_edges,
2626
softmax_nodes,
2727
softmax_edges,
28-
broadcast_nodes,
28+
broadcast_nodes,
2929
broadcast_edges,
3030
softmax_edge_neighbors,
3131

3232
# msgpass
3333
apply_edges,
34-
aggregate_neighbors,
34+
aggregate_neighbors,
3535
propagate,
36-
copy_xj,
37-
copy_xi,
36+
copy_xj,
37+
copy_xi,
3838
xi_dot_xj,
39+
e_mul_xj,
3940

4041
# layers/basic
4142
GNNLayer,

src/layers/conv.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -832,16 +832,15 @@ paper. In the forward pass, takes as inputs node features `x` and edge features
832832
updated features `x'` and `e'` according to
833833
834834
```math
835-
\mathbf{e}_{i\to j}' = \phi_e([\mathbf{x}_i; \mathbf{x}_j; \mathbf{e}_{i\to j}])\\
836-
\mathbf{x}_{i}' = \phi_v([\mathbf{x}_i; \square_{j\in \mathcal{N}(i)\,\mathbf{e}_{j\to i}'])
835+
\mathbf{e}_{i\to j}' = \phi_e([\mathbf{x}_i; \mathbf{x}_j; \mathbf{e}_{i\to j}]),\\
836+
\mathbf{x}_{i}' = \phi_v([\mathbf{x}_i; \square_{j\in \mathcal{N}(i)\,\mathbf{e}_{j\to i}']).
837837
```
838838
839839
`aggr` defines the aggregation to be performed.
840840
841841
If the neural networks `ϕe` and `ϕv` are not provided, they will be constructed from
842842
the `in` and `out` arguments instead as multi-layer perceptron with one hidden layer and `relu`
843843
activations.
844-
````
845844
846845
# Examples
847846

test/GNNGraphs/gnngraph.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,36 @@
11
@testset "GNNGraph" begin
2+
3+
@testset "Constructor: adjacency matrix" begin
4+
A = sprand(10, 10, 0.5)
5+
sA, tA, vA = findnz(A)
6+
7+
g = GNNGraph(A, graph_type=GRAPH_T)
8+
s, t = edge_index(g)
9+
v = get_edge_weight(g)
10+
@test s == sA
11+
@test t == tA
12+
@test v == vA
13+
14+
g = GNNGraph(Matrix(A), graph_type=GRAPH_T)
15+
s, t = edge_index(g)
16+
v = get_edge_weight(g)
17+
@test s == sA
18+
@test t == tA
19+
@test v == vA
20+
21+
g = GNNGraph([0 0 0
22+
0 0 1
23+
0 1 0], graph_type=GRAPH_T)
24+
@test g.num_nodes == 3
25+
@test g.num_edges == 2
26+
27+
g = GNNGraph([0 1 0
28+
1 0 0
29+
0 0 0], graph_type=GRAPH_T)
30+
@test g.num_nodes == 3
31+
@test g.num_edges == 2
32+
end
33+
234
@testset "symmetric graph" begin
335
s = [1, 1, 2, 2, 3, 3, 4, 4]
436
t = [2, 4, 1, 3, 2, 4, 1, 3]
@@ -124,7 +156,7 @@
124156
@test adjacency_list(g, dir=:in) == adj_list_in
125157
end
126158

127-
@testset "Graphs constructor" begin
159+
@testset "Graphs.jl constructor" begin
128160
lg = random_regular_graph(10, 4)
129161
@test !Graphs.is_directed(lg)
130162
g = GNNGraph(lg)

test/GNNGraphs/query.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,12 @@
9292
@test eltype(L) == eltype(g)
9393
@test L D - A
9494
end
95+
96+
@testset "adjacency_matrix" begin
97+
a = sprand(5, 5, 0.5)
98+
g = GNNGraph(a, graph_type=GRAPH_T)
99+
A = adjacency_matrix(g, Float32)
100+
@test a A
101+
@test eltype(A) == Float32
102+
end
95103
end

0 commit comments

Comments
 (0)