Skip to content

Commit a0a7fad

Browse files
authored
Coalesce function and state for coo graphs (#613)
* Add coalesce_graph function for COO graphs and is_coalesced flag
* Simplify coalesce_graph function
* Revert back to old sort_edge_index function
* Rename remove_multi_edges to coalesce, update references, add deprecation
* Add is_coalesced function to check if graph is coalesced
* Add tests to verify is_coalesced for various graph transformations
1 parent ea6c42b commit a0a7fad

File tree

6 files changed

+90
-15
lines changed

6 files changed

+90
-15
lines changed

GNNGraphs/src/GNNGraphs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ export add_nodes,
7575
rand_edge_split,
7676
remove_self_loops,
7777
remove_edges,
78-
remove_multi_edges,
78+
coalesce,
7979
set_edge_weight,
8080
to_bidirected,
8181
to_unidirected,

GNNGraphs/src/deprecations.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ function Base.getproperty(vds::Vector{DataStore}, s::Symbol)
1111
return [getdata(ds)[s] for ds in vds]
1212
end
1313
end
14+
15+
@deprecate remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +) Base.coalesce(g; aggr = aggr)

GNNGraphs/src/gnngraph.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,20 @@ struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
113113
ndata::DataStore
114114
edata::DataStore
115115
gdata::DataStore
116+
is_coalesced::Bool # only for :coo, true if the graph is coalesced, i.e., indices ordered by row and no multi edges
117+
end
118+
119+
# GNNGraph constructor setting the is_coalesced field to false
120+
function GNNGraph(graph::T,
121+
num_nodes::Int,
122+
num_edges::Int,
123+
num_graphs::Int,
124+
graph_indicator::Union{Nothing, AVecI},
125+
ndata::DataStore,
126+
edata::DataStore,
127+
gdata::DataStore) where {T <: Union{COO_T, ADJMAT_T}}
128+
return GNNGraph{T}(graph, num_nodes, num_edges, num_graphs,
129+
graph_indicator, ndata, edata, gdata, false)
116130
end
117131

118132
function GNNGraph(data::D;

GNNGraphs/src/query.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,18 @@ Graphs.ne(g::GNNGraph) = g.num_edges
103103
Graphs.has_vertex(g::GNNGraph, i::Int) = 1 <= i <= g.num_nodes
104104
Graphs.vertices(g::GNNGraph) = 1:(g.num_nodes)
105105

106+
"""
107+
is_coalesced(g::GNNGraph) -> Bool
108+
109+
Check whether the given `GNNGraph` `g` is coalesced (see [`coalesce`](@ref)). Only meaningful for COO graphs.
110+
111+
# Arguments
112+
- `g::GNNGraph`: The graph to check.
113+
114+
# Returns
115+
- `Bool`: Whether the graph is coalesced. If the graph is not of type COO, this function will always return `false`.
116+
"""
117+
is_coalesced(g::GNNGraph) = g.is_coalesced
106118

107119
"""
108120
neighbors(g::GNNGraph, i::Integer; dir=:out)

GNNGraphs/src/transform.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ end
4444
Return a graph constructed from `g` where self-loops (edges from a node to itself)
4545
are removed.
4646
47-
See also [`add_self_loops`](@ref) and [`remove_multi_edges`](@ref).
47+
See also [`add_self_loops`](@ref) and [`coalesce`](@ref).
4848
"""
4949
function remove_self_loops(g::GNNGraph{<:COO_T})
5050
s, t = edge_index(g)
@@ -146,15 +146,14 @@ function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5)
146146
end
147147

148148
"""
149-
remove_multi_edges(g::GNNGraph; aggr=+)
149+
coalesce(g::GNNGraph; aggr=+)
150150
151-
Remove multiple edges (also called parallel edges or repeated edges) from graph `g`.
152-
Possible edge features are aggregated according to `aggr`, that can take value
153-
`+`,`min`, `max` or `mean`.
151+
Return a new `GNNGraph` where all multiple edges between the same pair of nodes are merged (using `aggr` to combine edge weights and features), and the edge indices are sorted lexicographically (by source, then target).
152+
This method is only applicable to graphs of type `:coo`.
154153
155-
See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref).
154+
`aggr` can take the values `+`, `min`, `max`, or `mean`.
156155
"""
157-
function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +)
156+
function Base.coalesce(g::GNNGraph{<:COO_T}; aggr = +)
158157
s, t = edge_index(g)
159158
w = get_edge_weight(g)
160159
edata = g.edata
@@ -181,7 +180,7 @@ function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +)
181180
return GNNGraph((s, t, w),
182181
g.num_nodes, num_edges, g.num_graphs,
183182
g.graph_indicator,
184-
g.ndata, edata, g.gdata)
183+
g.ndata, edata, g.gdata, true)
185184
end
186185

187186
"""
@@ -441,7 +440,7 @@ end
441440
to_bidirected(g)
442441
443442
Adds a reverse edge for each edge in the graph, then calls
444-
[`remove_multi_edges`](@ref) with `mean` aggregation to simplify the graph.
443+
[`coalesce`](@ref) with `mean` aggregation to simplify the graph.
445444
446445
See also [`is_bidirected`](@ref).
447446
@@ -505,7 +504,7 @@ function to_bidirected(g::GNNGraph{<:COO_T})
505504
g.graph_indicator,
506505
g.ndata, edata, g.gdata)
507506

508-
return remove_multi_edges(g; aggr = mean)
507+
return coalesce(g; aggr = mean)
509508
end
510509

511510
"""
@@ -525,7 +524,7 @@ function to_unidirected(g::GNNGraph{<:COO_T})
525524
g.graph_indicator,
526525
g.ndata, g.edata, g.gdata)
527526

528-
return remove_multi_edges(g; aggr = mean)
527+
return coalesce(g; aggr = mean)
529528
end
530529

531530
function Graphs.SimpleGraph(g::GNNGraph)

GNNGraphs/test/transform.jl

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,22 +345,22 @@ end
345345
end
346346
end
347347

348-
@testitem "remove_multi_edges" setup=[GraphsTestModule] begin
348+
@testitem "coalesce" setup=[GraphsTestModule] begin
349349
using .GraphsTestModule
350350
for GRAPH_T in GRAPH_TYPES
351351
if GRAPH_T == :coo
352352
g = rand_graph(10, 20, graph_type = GRAPH_T)
353353
s, t = edge_index(g)
354354
g1 = add_edges(g, s[1:5], t[1:5])
355355
@test g1.num_edges == g.num_edges + 5
356-
g2 = remove_multi_edges(g1, aggr = +)
356+
g2 = coalesce(g1, aggr = +)
357357
@test g2.num_edges == g.num_edges
358358
@test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g))
359359

360360
# Default aggregation is +
361361
g1 = GNNGraph(g1, edata = (e1 = ones(3, g1.num_edges), e2 = 2 * ones(g1.num_edges)))
362362
g1 = set_edge_weight(g1, 3 * ones(g1.num_edges))
363-
g2 = remove_multi_edges(g1)
363+
g2 = coalesce(g1)
364364
@test g2.num_edges == g.num_edges
365365
@test sort_edge_index(edge_index(g2)) == sort_edge_index(edge_index(g))
366366
@test count(g2.edata.e1[:, i] == 2 * ones(3) for i in 1:(g2.num_edges)) == 5
@@ -714,3 +714,51 @@ end
714714
end
715715
end
716716
end
717+
718+
@testitem "graph transform ops set is_coalesced=false" setup=[GraphsTestModule] begin
719+
using .GraphsTestModule
720+
g = rand_graph(5, 10, graph_type=:coo)
721+
g = coalesce(g) # ensure the graph is coalesced to start with
722+
723+
# add_self_loops
724+
g1 = add_self_loops(g)
725+
@test g1.is_coalesced == false
726+
727+
# remove_self_loops
728+
g2 = add_self_loops(g) # ensure there are self-loops to remove
729+
g2 = remove_self_loops(g2)
730+
@test g2.is_coalesced == false
731+
732+
# remove_edges
733+
g3 = remove_edges(g, [1])
734+
@test g3.is_coalesced == false
735+
736+
# add_edges
737+
g4 = add_edges(g, [1], [2])
738+
@test g4.is_coalesced == false
739+
740+
# perturb_edges
741+
g5 = perturb_edges(g, 0.5)
742+
@test g5.is_coalesced == false
743+
744+
# remove_nodes
745+
g6 = remove_nodes(g, [1])
746+
@test g6.is_coalesced == false
747+
748+
# add_nodes
749+
g7 = add_nodes(g, 2)
750+
@test g7.is_coalesced == false
751+
752+
# rand_edge_split returns two graphs
753+
g8a, g8b = rand_edge_split(g, 0.5)
754+
@test g8a.is_coalesced == false
755+
@test g8b.is_coalesced == false
756+
757+
# negative_sample
758+
g9 = negative_sample(g, num_neg_edges=3)
759+
@test g9.is_coalesced == false
760+
761+
# ppr_diffusion
762+
g11 = ppr_diffusion(g)
763+
@test g11.is_coalesced == false
764+
end

0 commit comments

Comments
 (0)