Skip to content

Commit 55fe50b

Browse files
get_graph_type
1 parent ebab567 commit 55fe50b

File tree

5 files changed

+81
-6
lines changed

5 files changed

+81
-6
lines changed

GNNGraphs/src/GNNGraphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ include("query.jl")
4747
export adjacency_list,
4848
edge_index,
4949
get_edge_weight,
50+
get_graph_type,
5051
graph_indicator,
5152
has_multi_edges,
5253
is_directed,

GNNGraphs/src/operators.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ Intersect two graphs by keeping only the common edges.
66
"""
77
function Base.intersect(g1::GNNGraph, g2::GNNGraph)
88
@assert g1.num_nodes == g2.num_nodes
9-
@assert graph_type_symbol(g1) == graph_type_symbol(g2)
10-
graph_type = graph_type_symbol(g1)
9+
@assert get_graph_type(g1) == get_graph_type(g2)
10+
graph_type = get_graph_type(g1)
1111
num_nodes = g1.num_nodes
1212

1313
idx1, _ = edge_encoding(edge_index(g1)..., num_nodes)

GNNGraphs/src/query.jl

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,61 @@ function Graphs.has_edge(g::GNNHeteroGraph, edge_t::EType, i::Integer, j::Intege
8080
return any((s .== i) .& (t .== j))
8181
end
8282

83-
graph_type_symbol(::GNNGraph{<:COO_T}) = :coo
84-
graph_type_symbol(::GNNGraph{<:SPARSE_T}) = :sparse
85-
graph_type_symbol(::GNNGraph{<:ADJMAT_T}) = :dense
83+
"""
84+
get_graph_type(g::GNNGraph)
85+
86+
Return the underlying representation for the graph `g` as a symbol.
87+
88+
Possible values are:
89+
- `:coo`: Coordinate list representation. The graph is stored as a tuple of vectors `(s, t, w)`,
90+
where `s` and `t` are the source and target nodes of the edges, and `w` is the edge weights.
91+
- `:sparse`: Sparse matrix representation. The graph is stored as a sparse matrix representing the weighted adjacency matrix.
92+
- `:dense`: Dense matrix representation. The graph is stored as a dense matrix representing the weighted adjacency matrix.
93+
94+
The default representation for graph constructors GNNGraphs.jl is `:coo`.
95+
The underlying representation can be accessed through the `g.graph` field.
96+
97+
See also [`GNNGraph`](@ref).
98+
99+
# Examples
100+
101+
The default representation for graph constructors GNNGraphs.jl is `:coo`.
102+
```jldoctest
103+
julia> g = rand_graph(5, 10)
104+
GNNGraph:
105+
num_nodes: 5
106+
num_edges: 10
107+
108+
julia> get_graph_type(g)
109+
:coo
110+
```
111+
The `GNNGraph` constructor can also be used to create graphs with different representations.
112+
```jldoctest
113+
julia> g = GNNGraph([2,3,5], [1,2,4], graph_type=:sparse)
114+
GNNGraph:
115+
num_nodes: 5
116+
num_edges: 3
117+
118+
julia> g.graph
119+
5×5 SparseArrays.SparseMatrixCSC{Int64, Int64} with 3 stored entries:
120+
⋅ ⋅ ⋅ ⋅ ⋅
121+
1 ⋅ ⋅ ⋅ ⋅
122+
⋅ 1 ⋅ ⋅ ⋅
123+
⋅ ⋅ ⋅ ⋅ ⋅
124+
⋅ ⋅ ⋅ 1 ⋅
125+
126+
julia> get_graph_type(g)
127+
:sparse
128+
129+
julia> gcoo = GNNGraph(g, graph_type=:coo);
130+
131+
julia> gcoo.graph
132+
([2, 3, 5], [1, 2, 4], [1, 1, 1])
133+
```
134+
"""
135+
get_graph_type(::GNNGraph{<:COO_T}) = :coo
136+
get_graph_type(::GNNGraph{<:SPARSE_T}) = :sparse
137+
get_graph_type(::GNNGraph{<:ADJMAT_T}) = :dense
86138

87139
Graphs.nv(g::GNNGraph) = g.num_nodes
88140
Graphs.ne(g::GNNGraph) = g.num_edges

GNNGraphs/test/query.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,20 @@ if GRAPH_T == :coo
257257
end
258258
end
259259

260+
@testset "get_graph_type" begin
261+
g = rand_graph(10, 20, graph_type = GRAPH_T)
262+
@test get_graph_type(g) == GRAPH_T
263+
264+
gsparse = GNNGraph(g, graph_type=:sparse)
265+
@test get_graph_type(gsparse) == :sparse
266+
@test gsparse.graph isa SparseMatrixCSC
267+
268+
gcoo = GNNGraph(g, graph_type=:coo)
269+
@test get_graph_type(gcoo) == :coo
270+
@test gcoo.graph[1:2] isa Tuple{Vector{Int}, Vector{Int}}
271+
272+
273+
gdense = GNNGraph(g, graph_type=:dense)
274+
@test get_graph_type(gdense) == :dense
275+
@test gdense.graph isa Matrix{Int}
276+
end

GraphNeuralNetworks/test/test_module.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,20 @@ function finitediff_withgradient(f, x...)
5959
end
6060

6161
function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
62+
equal = true
6263
fmapstructure_with_path(a, b) do kp, x, y
6364
if x isa AbstractArray
6465
# @show kp
65-
@assert x y rtol=rtol atol=atol
66+
# @assert x ≈ y rtol=rtol atol=atol
67+
if !isapprox(x, y; rtol, atol)
68+
equal = false
69+
end
6670
# elseif x isa Number
6771
# @show kp
6872
# @assert x ≈ y rtol=rtol atol=atol
6973
end
7074
end
75+
@assert equal
7176
end
7277

7378
function test_gradients(

0 commit comments

Comments
 (0)