Skip to content

Commit 95a90fc

Browse files
Added remove_edges function (#414)
* added remove edge function * tests * added remove edge function * fix * fix * fix * fix * fix * Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update transform.jl * tests final * Update Project.toml * Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> * done * fixes * more tests --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 374d8fb commit 95a90fc

File tree

4 files changed

+81
-2
lines changed

4 files changed

+81
-2
lines changed

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,4 @@ Manifest.toml
88
/docs/build/
99
.vscode
1010
LocalPreferences.toml
11-
.DS_Store
12-
/test.jl
11+
.DS_Store

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ export add_nodes,
7272
negative_sample,
7373
rand_edge_split,
7474
remove_self_loops,
75+
remove_edges,
7576
remove_multi_edges,
7677
set_edge_weight,
7778
to_bidirected,

src/GNNGraphs/transform.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,57 @@ function remove_self_loops(g::GNNGraph{<:ADJMAT_T})
149149
g.ndata, g.edata, g.gdata)
150150
end
151151

152+
"""
153+
remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer})
154+
155+
Remove specified edges from a GNNGraph.
156+
157+
# Arguments
158+
- `g`: The input graph from which edges will be removed.
159+
- `edges_to_remove`: Vector of edge indices to be removed.
160+
161+
# Returns
162+
A new GNNGraph with the specified edges removed.
163+
164+
# Example
165+
```julia
166+
julia> using GraphNeuralNetworks
167+
168+
# Construct a GNNGraph
169+
julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1])
170+
GNNGraph:
171+
num_nodes: 3
172+
num_edges: 5
173+
174+
# Remove the second edge
175+
julia> g_new = remove_edges(g, [2]);
176+
177+
julia> g_new
178+
GNNGraph:
179+
num_nodes: 3
180+
num_edges: 4
181+
```
182+
"""
183+
function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer})
184+
s, t = edge_index(g)
185+
w = get_edge_weight(g)
186+
edata = g.edata
187+
188+
mask_to_keep = trues(length(s))
189+
190+
mask_to_keep[edges_to_remove] .= false
191+
192+
s = s[mask_to_keep]
193+
t = t[mask_to_keep]
194+
edata = getobs(edata, mask_to_keep)
195+
w = isnothing(w) ? nothing : getobs(w, mask_to_keep)
196+
197+
return GNNGraph((s, t, w),
198+
g.num_nodes, length(s), g.num_graphs,
199+
g.graph_indicator,
200+
g.ndata, edata, g.gdata)
201+
end
202+
152203
"""
153204
remove_multi_edges(g::GNNGraph; aggr=+)
154205

test/GNNGraphs/transform.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,34 @@ end
101101
@test nodemap == 1:(g1.num_nodes)
102102
end
103103

104+
@testset "remove_edges" begin
105+
if GRAPH_T == :coo
106+
s = [1, 1, 2, 3]
107+
t = [2, 3, 4, 5]
108+
w = [0.1, 0.2, 0.3, 0.4]
109+
edata = ['a', 'b', 'c', 'd']
110+
g = GNNGraph(s, t, w, edata = edata, graph_type = GRAPH_T)
111+
112+
# single edge removal
113+
gnew = remove_edges(g, [1])
114+
new_s, new_t = edge_index(gnew)
115+
@test gnew.num_edges == 3
116+
@test new_s == s[2:end]
117+
@test new_t == t[2:end]
118+
119+
# multiple edge removal
120+
gnew = remove_edges(g, [1,2,4])
121+
new_s, new_t = edge_index(gnew)
122+
new_w = get_edge_weight(gnew)
123+
new_edata = gnew.edata.e
124+
@test gnew.num_edges == 1
125+
@test new_s == [2]
126+
@test new_t == [4]
127+
@test new_w == [0.3]
128+
@test new_edata == ['c']
129+
end
130+
end
131+
104132
@testset "add_edges" begin
105133
if GRAPH_T == :coo
106134
s = [1, 1, 2, 3]

0 commit comments

Comments
 (0)