Skip to content

Commit 9b8bbf1

Browse files
fix tests
1 parent 95a9af1 commit 9b8bbf1

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

src/GNNGraphs/query.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,15 +182,15 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=nothing; dir=:out, edge_weight
182182
vec(sum(A, dims=1)) .+ vec(sum(A, dims=2))
183183
end
184184

185-
function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=Float32; dir::Symbol=:out)
185+
function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=nodetype(g); dir::Symbol=:out)
186186
A = adjacency_matrix(g, T; dir=dir)
187187
D = Diagonal(vec(sum(A; dims=2)))
188188
return D - A
189189
end
190190

191191

192192
"""
193-
normalized_laplacian(g, T=nothing; add_self_loops=false, dir=:out)
193+
normalized_laplacian(g, T=Float32; add_self_loops=false, dir=:out)
194194
195195
Normalized Laplacian matrix of graph `g`.
196196
@@ -219,7 +219,7 @@ function normalized_adjacency(g::GNNGraph, T::DataType=Float32;
219219
end
220220

221221
@doc raw"""
222-
scaled_laplacian(g, T=nothing; dir=:out)
222+
scaled_laplacian(g, T=Float32; dir=:out)
223223
224224
Scaled Laplacian matrix of graph `g`,
225225
defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix.

test/GNNGraphs/query.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
@test Array(d_gpu) == d
5353
end
5454
end
55-
55+
5656
@testset "weighted" begin
5757
# weighted degree
5858
s = [1, 1, 2, 3]
@@ -84,4 +84,13 @@
8484
end
8585
end
8686
end
87+
88+
@testset "laplacian_matrix" begin
89+
g = rand_graph(10, 30, graph_type=GRAPH_T)
90+
A = adjacency_matrix(g)
91+
D = Diagonal(vec(sum(A, dims=2)))
92+
L = laplacian_matrix(g)
93+
@test eltype(L) == GNNGraphs.nodetype(g)
94+
@test L D - A
95+
end
8796
end

test/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
end
4141

4242
@testset "ChebConv" begin
43-
k = 3
43+
k = 2
4444
l = ChebConv(in_channel => out_channel, k)
4545
@test size(l.weight) == (out_channel, in_channel, k)
4646
@test size(l.bias) == (out_channel,)

0 commit comments

Comments
 (0)