Skip to content

Commit dda020c

Browse files
authored
Added remove_nodes function (#420)
* remove nodes * Update transform.jl * remove nodes * Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> * fix * made efficient using searchsortedlast * more tests * remove ndata getobs * Update test/GNNGraphs/transform.jl ---------
1 parent 18c4606 commit dda020c

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ export add_nodes,
7777
to_bidirected,
7878
to_unidirected,
7979
random_walk_pe,
80+
remove_nodes,
8081
# from Flux
8182
batch,
8283
unbatch,

src/GNNGraphs/transform.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,66 @@ function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +)
188188
g.ndata, edata, g.gdata)
189189
end
190190

191+
"""
192+
remove_nodes(g::GNNGraph, nodes_to_remove::AbstractVector)
193+
194+
Remove specified nodes, and their associated edges, from a GNNGraph. This operation reindexes the remaining nodes to maintain a continuous sequence of node indices, starting from 1. Similarly, edges are reindexed to account for the removal of edges connected to the removed nodes.
195+
196+
# Arguments
197+
- `g`: The input graph from which nodes (and their edges) will be removed.
198+
- `nodes_to_remove`: Vector of node indices to be removed.
199+
200+
# Returns
201+
A new GNNGraph with the specified nodes and all edges associated with these nodes removed.
202+
203+
# Example
204+
```julia
205+
using GraphNeuralNetworks
206+
207+
g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1])
208+
209+
# Remove nodes with indices 2 and 3, for example
210+
g_new = remove_nodes(g, [2, 3])
211+
212+
# g_new now does not contain nodes 2 and 3, and any edges that were connected to these nodes.
213+
println(g_new)
214+
```
215+
"""
216+
function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
217+
nodes_to_remove = sort(union(nodes_to_remove))
218+
s, t = edge_index(g)
219+
w = get_edge_weight(g)
220+
edata = g.edata
221+
ndata = g.ndata
222+
223+
edges_to_remove_s = findall(x -> nodes_to_remove[searchsortedlast(nodes_to_remove, x)] == x, s)
224+
edges_to_remove_t = findall(x -> nodes_to_remove[searchsortedlast(nodes_to_remove, x)] == x, t)
225+
edges_to_remove = union(edges_to_remove_s, edges_to_remove_t)
226+
227+
mask_edges_to_keep = trues(length(s))
228+
mask_edges_to_keep[edges_to_remove] .= false
229+
s = s[mask_edges_to_keep]
230+
t = t[mask_edges_to_keep]
231+
232+
w = isnothing(w) ? nothing : getobs(w, mask_edges_to_keep)
233+
234+
for node in sort(nodes_to_remove, rev=true)
235+
s[s .> node] .-= 1
236+
t[t .> node] .-= 1
237+
end
238+
239+
nodes_to_keep = setdiff(1:g.num_nodes, nodes_to_remove)
240+
ndata = getobs(ndata, nodes_to_keep)
241+
edata = getobs(edata, mask_edges_to_keep)
242+
243+
num_nodes = g.num_nodes - length(nodes_to_remove)
244+
245+
return GNNGraph((s, t, w),
246+
num_nodes, length(s), g.num_graphs,
247+
g.graph_indicator,
248+
ndata, edata, g.gdata)
249+
end
250+
191251
"""
192252
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
193253
add_edges(g::GNNGraph, (s, t); [edata])

test/GNNGraphs/transform.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,37 @@ end
149149
end
150150
end
151151

152+
@testset "remove_nodes" begin if GRAPH_T == :coo
153+
s = [1, 1, 2, 3]
154+
t = [2, 3, 4, 5]
155+
eweights = [0.1, 0.2, 0.3, 0.4]
156+
ndata = [1.0, 2.0, 3.0, 4.0, 5.0]
157+
edata = ['a', 'b', 'c', 'd']
158+
159+
g = GNNGraph(s, t, eweights, ndata = ndata, edata = edata, graph_type = GRAPH_T)
160+
161+
gnew = remove_nodes(g, [1])
162+
163+
snew = [1, 2]
164+
tnew = [3, 4]
165+
eweights_new = [0.3, 0.4]
166+
ndata_new = [2.0, 3.0, 4.0, 5.0]
167+
edata_new = ['c', 'd']
168+
169+
stest, ttest = edge_index(gnew)
170+
eweightstest = get_edge_weight(gnew)
171+
ndatatest = gnew.ndata.x
172+
edatatest = gnew.edata.e
173+
174+
175+
@test gnew.num_edges == 2
176+
@test gnew.num_nodes == 4
177+
@test snew == stest
178+
@test tnew == ttest
179+
@test eweights_new == eweightstest
180+
@test ndata_new == ndatatest
181+
@test edata_new == edatatest
182+
end end
152183

153184
@testset "add_nodes" begin if GRAPH_T == :coo
154185
g = rand_graph(6, 4, ndata = rand(2, 6), graph_type = GRAPH_T)

0 commit comments

Comments
 (0)