Skip to content

Commit 51e7894

Browse files
intersect
1 parent fe26097 commit 51e7894

File tree

9 files changed

+68
-4
lines changed

9 files changed

+68
-4
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ export add_nodes,
5454
include("generate.jl")
5555
export rand_graph
5656

57+
include("operators.jl")
58+
# Base.intersect
59+
5760
include("convert.jl")
5861
include("utils.jl")
5962

src/GNNGraphs/convert.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ function to_coo(coo::COO_T; dir=:out, num_nodes=nothing)
55
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
66
@assert isnothing(val) || length(val) == length(s)
77
@assert length(s) == length(t)
8-
@assert min(minimum(s), minimum(t)) >= 1
9-
@assert max(maximum(s), maximum(t)) <= num_nodes
10-
8+
if !isempty(s)
9+
@assert min(minimum(s), minimum(t)) >= 1
10+
@assert max(maximum(s), maximum(t)) <= num_nodes
11+
end
1112
num_edges = length(s)
1213
return coo, num_nodes, num_edges
1314
end

src/GNNGraphs/gnngraph.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ function GNNGraph(data;
150150
ndata, edata, gdata)
151151
end
152152

153+
function GNNGraph(n::T; graph_type=:coo, kws...) where {T<:Integer}
154+
s, t = T[], T[]
155+
return GNNGraph(s, t; graph_type, num_nodes=n, kws...)
156+
end
157+
153158
# COO convenience constructors
154159
GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) = GNNGraph((s, t, v); kws...)
155160
GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)

src/GNNGraphs/operators.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# 2 or more args graph operators
2+
function Base.intersect(g1::GNNGraph, g2::GNNGraph)
3+
@assert g1.num_nodes == g2.num_nodes
4+
@assert graph_type_symbol(g1) == graph_type_symbol(g2)
5+
graph_type = graph_type_symbol(g1)
6+
num_nodes = g1.num_nodes
7+
8+
idx1, _ = edge_encoding(edge_index(g1)..., num_nodes)
9+
idx2, _ = edge_encoding(edge_index(g2)..., num_nodes)
10+
idx = intersect(idx1, idx2)
11+
s, t = edge_decoding(idx, num_nodes)
12+
return GNNGraph(s, t; num_nodes, graph_type)
13+
end

src/GNNGraphs/query.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ end
2828

2929
Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i,j] != 0
3030

31+
graph_type_symbol(g::GNNGraph{<:COO_T}) = :coo
32+
graph_type_symbol(g::GNNGraph{<:SPARSE_T}) = :sparse
33+
graph_type_symbol(g::GNNGraph{<:ADJMAT_T}) = :dense
34+
3135
Graphs.nv(g::GNNGraph) = g.num_nodes
3236
Graphs.ne(g::GNNGraph) = g.num_edges
3337
Graphs.has_vertex(g::GNNGraph, i::Int) = 1 <= i <= g.num_nodes

test/GNNGraphs/operators.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
@testset "Operators" begin
2+
@testset "intersect" begin
3+
g = rand_graph(10, 20, graph_type=GRAPH_T)
4+
@test intersect(g, g).num_edges == 20
5+
end
6+
end

test/GNNGraphs/transform.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,16 @@
128128
@test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g))
129129
end
130130
end
131+
132+
@testset "negative_sample" begin
133+
if GRAPH_T == :coo
134+
n, m = 10,30
135+
g = rand_graph(n, m, bidirected=true, graph_type=GRAPH_T)
136+
137+
# check bidirected=is_bidirected(g) default
138+
gneg = negative_sample(g, num_neg_edges=20)
139+
@test is_bidirected(gneg)
140+
@test intersect(g, gneg).num_edges == 0
141+
end
142+
end
131143
end

test/GNNGraphs/utils.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
@test sdec == s
1515
@test tdec == t
1616

17+
n1, m1 = 10, 30
18+
g = rand_graph(n1, m1)
19+
s1, t1 = edge_index(g)
20+
idx, maxid = GNNGraphs.edge_encoding(s1, t1, n1)
21+
sdec, tdec = GNNGraphs.edge_decoding(idx, n1)
22+
@test sdec == s1
23+
@test tdec == t1
24+
1725
# directed=false
1826
idx, maxid = GNNGraphs.edge_encoding(s, t, n, directed=false)
1927
@test maxid == n*(n+1)÷2
@@ -28,6 +36,17 @@
2836
@test sdec == snew
2937
@test tdec == tnew
3038

31-
g = rand_graph(10, 30, bidirected=true)
39+
n1, m1 = 6, 8
40+
g = rand_graph(n1, m1)
41+
s1, t1 = edge_index(g)
42+
idx, maxid = GNNGraphs.edge_encoding(s1, t1, n1, directed=false)
43+
sdec, tdec = GNNGraphs.edge_decoding(idx, n1, directed=false)
44+
mask = s1 .> t1
45+
snew = copy(s1)
46+
tnew = copy(t1)
47+
snew[mask] .= t1[mask]
48+
tnew[mask] .= s1[mask]
49+
@test sdec == snew
50+
@test tdec == tnew
3251
end
3352
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ include("test_utils.jl")
2323
tests = [
2424
"GNNGraphs/gnngraph",
2525
"GNNGraphs/transform",
26+
"GNNGraphs/operators",
2627
"GNNGraphs/generate",
2728
"GNNGraphs/query",
2829
"utils",

0 commit comments

Comments
 (0)