Skip to content

Commit c82efa0

Browse files
[GNNGraphs] implement remove_edges(g, p) (#474)
* new drop edge * fix * Update GNNGraphs/src/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update GNNGraphs/src/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 4b4477e commit c82efa0

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

GNNGraphs/src/transform.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,14 @@ end
151151

152152
"""
153153
remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer})
154+
remove_edges(g::GNNGraph, p=0.5)
154155
155-
Remove specified edges from a GNNGraph.
156+
Remove specified edges from a GNNGraph, either by specifying edge indices or by randomly removing edges with a given probability.
156157
157158
# Arguments
158159
- `g`: The input graph from which edges will be removed.
159-
- `edges_to_remove`: Vector of edge indices to be removed.
160+
- `edges_to_remove`: Vector of edge indices to be removed. This argument is only required for the first method.
161+
- `p`: Probability of removing each edge. This argument is only required for the second method and defaults to 0.5.
160162
161163
# Returns
162164
A new GNNGraph with the specified edges removed.
@@ -178,6 +180,14 @@ julia> g_new
178180
GNNGraph:
179181
num_nodes: 3
180182
num_edges: 4
183+
184+
# Remove edges with a probability of 0.5
185+
julia> g_new = remove_edges(g, 0.5);
186+
187+
julia> g_new
188+
GNNGraph:
189+
num_nodes: 3
190+
num_edges: 2
181191
```
182192
"""
183193
function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer})
@@ -200,6 +210,13 @@ function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:In
200210
g.ndata, edata, g.gdata)
201211
end
202212

213+
214+
function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5)
215+
num_edges = g.num_edges
216+
edges_to_remove = filter(_ -> rand() < p, 1:num_edges)
217+
return remove_edges(g, edges_to_remove)
218+
end
219+
203220
"""
204221
remove_multi_edges(g::GNNGraph; aggr=+)
205222

GNNGraphs/test/transform.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ end
126126
@test new_t == [4]
127127
@test new_w == [0.3]
128128
@test new_edata == ['c']
129+
130+
# drop with probability
131+
gnew = remove_edges(g, Float32(1.0))
132+
@test gnew.num_edges == 0
133+
134+
gnew = remove_edges(g, Float32(0.0))
135+
@test gnew.num_edges == g.num_edges
129136
end
130137
end
131138

0 commit comments

Comments
 (0)