Skip to content

Commit 156cfae

Browse files
Implemented khop_adj (#239)
* Add khop_adj function * Add tests khop_adj function * Add export khop_adj * Align with adjacency_matrix signature * Add more tests * Cleaner inputs Co-authored-by: Carlo Lucibello <[email protected]> Co-authored-by: Carlo Lucibello <[email protected]>
1 parent eae9575 commit 156cfae

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ export adjacency_list,
4141
has_self_loops,
4242
has_isolated_nodes,
4343
inneighbors,
44-
outneighbors
44+
outneighbors,
45+
khop_adj
4546

4647
include("transform.jl")
4748
export add_nodes,

src/GNNGraphs/query.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,17 @@ function has_multi_edges(g::GNNGraph)
404404
length(union(idxs)) < length(idxs)
405405
end
406406

407+
"""
408+
khop_adj(g::GNNGraph,k::Int,T::DataType=eltype(g); dir=:out, weighted=true)
409+
410+
Return ``A^k`` where ``A`` is the adjacency matrix of the graph 'g'.
411+
412+
"""
413+
function khop_adj(g::GNNGraph,k::Int, T::DataType=eltype(g); dir=:out, weighted=true)
414+
return (adjacency_matrix(g, T; dir, weighted))^k
415+
end
416+
417+
407418
@non_differentiable edge_index(x...)
408419
@non_differentiable adjacency_list(x...)
409420
@non_differentiable graph_indicator(x...)

test/GNNGraphs/query.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,5 +164,18 @@
164164

165165
@test gw == [1,1,1]
166166
end
167+
168+
@testset "khop_adj" begin
169+
s = [1,2,3]
170+
t = [2,3,1]
171+
w = [0.1,0.1,0.2]
172+
g = GNNGraph(s, t, w)
173+
@test khop_adj(g,2)== adjacency_matrix(g)*adjacency_matrix(g)
174+
@test khop_adj(g,2,Int8;weighted=false) == sparse([0 0 1;1 0 0;0 1 0])
175+
@test khop_adj(g,2,Int8;dir=in,weighted=false) == sparse([0 0 1;1 0 0;0 1 0]')
176+
@test khop_adj(g,1) == adjacency_matrix(g)
177+
@test eltype(khop_adj(g,4)) == Float64
178+
@test eltype(khop_adj(g,10,Float32)) == Float32
179+
end
167180
end
168181
end

0 commit comments

Comments
 (0)