Skip to content

Commit 374d8fb

Browse files
Fixes and more tests in remove_nodes function (#424)
* need to fix * bounds check * added more tests * Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update transform.jl (re added previous change) --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 1fb54b6 commit 374d8fb

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

src/GNNGraphs/transform.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,16 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
219219
w = get_edge_weight(g)
220220
edata = g.edata
221221
ndata = g.ndata
222-
223-
edges_to_remove_s = findall(x -> nodes_to_remove[searchsortedlast(nodes_to_remove, x)] == x, s)
224-
edges_to_remove_t = findall(x -> nodes_to_remove[searchsortedlast(nodes_to_remove, x)] == x, t)
222+
223+
function find_edges_to_remove(nodes, nodes_to_remove)
224+
return findall(node_id -> begin
225+
idx = searchsortedlast(nodes_to_remove, node_id)
226+
idx >= 1 && idx <= length(nodes_to_remove) && nodes_to_remove[idx] == node_id
227+
end, nodes)
228+
end
229+
230+
edges_to_remove_s = find_edges_to_remove(s, nodes_to_remove)
231+
edges_to_remove_t = find_edges_to_remove(t, nodes_to_remove)
225232
edges_to_remove = union(edges_to_remove_s, edges_to_remove_t)
226233

227234
mask_edges_to_keep = trues(length(s))

test/GNNGraphs/transform.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ end
150150
end
151151

152152
@testset "remove_nodes" begin if GRAPH_T == :coo
153+
#single node
153154
s = [1, 1, 2, 3]
154155
t = [2, 3, 4, 5]
155156
eweights = [0.1, 0.2, 0.3, 0.4]
@@ -179,6 +180,35 @@ end
179180
@test eweights_new == eweightstest
180181
@test ndata_new == ndatatest
181182
@test edata_new == edatatest
183+
184+
# multiple nodes
185+
s = [1, 5, 2, 3]
186+
t = [2, 3, 4, 5]
187+
eweights = [0.1, 0.2, 0.3, 0.4]
188+
ndata = [1.0, 2.0, 3.0, 4.0, 5.0]
189+
edata = ['a', 'b', 'c', 'd']
190+
191+
g = GNNGraph(s, t, eweights, ndata = ndata, edata = edata, graph_type = GRAPH_T)
192+
193+
gnew = remove_nodes(g, [1,4])
194+
snew = [3,2]
195+
tnew = [2,3]
196+
eweights_new = [0.2,0.4]
197+
ndata_new = [2.0,3.0,5.0]
198+
edata_new = ['b','d']
199+
200+
stest, ttest = edge_index(gnew)
201+
eweightstest = get_edge_weight(gnew)
202+
ndatatest = gnew.ndata.x
203+
edatatest = gnew.edata.e
204+
205+
@test gnew.num_edges == 2
206+
@test gnew.num_nodes == 3
207+
@test snew == stest
208+
@test tnew == ttest
209+
@test eweights_new == eweightstest
210+
@test ndata_new == ndatatest
211+
@test edata_new == edatatest
182212
end end
183213

184214
@testset "add_nodes" begin if GRAPH_T == :coo

0 commit comments

Comments
 (0)