Skip to content

Commit 4ef3014

Browse files
ADJMAT_T can store edge weights
1 parent 48dfed3 commit 4ef3014

File tree

4 files changed

+20
-13
lines changed

4 files changed

+20
-13
lines changed

src/GNNGraphs/convert.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,18 @@ function to_dense(adj_list::ADJLIST_T, T::DataType=Int; dir=:out, num_nodes=noth
9696
A, num_nodes, num_edges
9797
end
9898

99-
function to_dense(coo::COO_T, T::DataType=Int; dir=:out, num_nodes=nothing)
99+
function to_dense(coo::COO_T, T::DataType=Nothing; dir=:out, num_nodes=nothing)
100100
# `dir` will be ignored since the input `coo` is always in source -> target format.
101101
# The output will always be a adjmat in :out format (e.g. A[i,j] denotes from i to j)
102102
s, t, val = coo
103103
n = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
104+
val = isnothing(val) ? eltype(s)(1) : val
105+
T = T == Nothing ? eltype(val) : T
104106
A = fill!(similar(s, T, (n, n)), 0)
105-
if isnothing(val)
106-
A[s .+ n .* (t .- 1)] .= 1 # exploiting linear indexing
107-
else
108-
A[s .+ n .* (t .- 1)] .= val # exploiting linear indexing
109-
end
107+
v = vec(A)
108+
idxs = s .+ n .* (t .- 1)
109+
NNlib.scatter!(+, v, val, idxs)
110+
# A[s .+ n .* (t .- 1)] .= val # exploiting linear indexing
110111
return A, n, length(s)
111112
end
112113

src/GNNGraphs/query.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,15 @@ function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out, edge_weight=tr
126126
end
127127

128128
function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=Int; dir=:out, edge_weight=true)
129-
@assert edge_weight === true
130-
@assert dir (:in, :out)
129+
@assert !(edge_weight isa AbstractArray) "passing the edge weights is not support by adjacency matrix representations"
130+
@assert dir (:in, :out, :both)
131131
A = adjacency_matrix(g, T)
132-
return dir == :out ? vec(sum(A, dims=2)) : vec(sum(A, dims=1))
132+
if (edge_weight === false) || (edge_weight === nothing)
133+
A = map(>(0), A)
134+
end
135+
return dir == :out ? vec(sum(A, dims=2)) :
136+
dir == :in ? vec(sum(A, dims=1)) :
137+
vec(sum(A, dims=1)) .+ vec(sum(A, dims=2))
133138
end
134139

135140
function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=Int; dir::Symbol=:out)

test/GNNGraphs/query.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
@testset "degree" begin
3737
s = [1, 1, 2, 3]
3838
t = [2, 2, 2, 4]
39-
eweight = [0.1, 2.1, 1.2, 1]
4039
g = GNNGraph(s, t, graph_type=GRAPH_T)
4140

4241
@test degree(g) == degree(g; dir=:out) == [2, 1, 1, 0] # default is outdegree
@@ -45,14 +44,14 @@
4544
@test eltype(degree(g, Float32)) == Float32
4645

4746
# weighted degree
48-
if GRAPH_T == :coo
47+
# if GRAPH_T == :coo
4948
eweight = [0.1, 2.1, 1.2, 1]
5049
g = GNNGraph((s, t, eweight), graph_type=GRAPH_T)
5150
@test degree(g) == [2.2, 1.2, 1.0, 0.0]
5251
@test degree(g, edge_weight=false) == [2, 1, 1, 0]
5352
@test degree(g, edge_weight=nothing) == [2, 1, 1, 0]
5453
@test degree(g, edge_weight=2*eweight) == [4.4, 2.4, 2.0, 0.0]
55-
end
54+
# end
5655

5756
if TEST_GPU
5857
d = degree(g)

test/GNNGraphs/transform.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,13 @@
131131

132132
@testset "negative_sample" begin
133133
if GRAPH_T == :coo
134-
n, m = 10,30
134+
n, m = 10, 30
135135
g = rand_graph(n, m, bidirected=true, graph_type=GRAPH_T)
136136

137137
# check bidirected=is_bidirected(g) default
138138
gneg = negative_sample(g, num_neg_edges=20)
139+
@test gneg.num_nodes == g.num_nodes
140+
@test gneg.num_edges == 20
139141
@test is_bidirected(gneg)
140142
@test intersect(g, gneg).num_edges == 0
141143
end

0 commit comments

Comments
 (0)