Skip to content

Commit bcce0cf

Browse files
Add drop_nodes transform (#426)
* drop node * tests * Update transform.jl * Update transform.jl * added to gnngraphs * error in test? * Update src/GNNGraphs/transform.jl float32 args Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/transform.jl arg fix Co-authored-by: Carlo Lucibello <[email protected]> --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 36e8373 commit bcce0cf

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ export add_nodes,
7979
to_unidirected,
8080
random_walk_pe,
8181
remove_nodes,
82+
drop_nodes,
8283
# from Flux
8384
batch,
8485
unbatch,

src/GNNGraphs/transform.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,38 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
306306
ndata, edata, g.gdata)
307307
end
308308

309+
"""
310+
drop_nodes(g::GNNGraph{<:COO_T}, p)
311+
312+
Randomly drop nodes (and their associated edges) from a GNNGraph based on a given probability.
313+
Dropping nodes is a technique that can be used for graph data augmentation, refering paper [DropNode](https://arxiv.org/pdf/2008.12578.pdf).
314+
315+
# Arguments
316+
- `g`: The input graph from which nodes (and their associated edges) will be dropped.
317+
- `p`: The probability of dropping each node. Default value is `0.5`.
318+
319+
# Returns
320+
A modified GNNGraph with nodes (and their associated edges) dropped based on the given probability.
321+
322+
# Example
323+
```julia
324+
using GraphNeuralNetworks
325+
# Construct a GNNGraph
326+
g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1], num_nodes=3)
327+
# Drop nodes with a probability of 0.5
328+
g_new = drop_node(g, 0.5)
329+
println(g_new)
330+
```
331+
"""
332+
function drop_nodes(g::GNNGraph{<:COO_T}, p = 0.5)
333+
num_nodes = g.num_nodes
334+
nodes_to_remove = filter(_ -> rand() < p, 1:num_nodes)
335+
336+
new_g = remove_nodes(g, nodes_to_remove)
337+
338+
return new_g
339+
end
340+
309341
"""
310342
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
311343
add_edges(g::GNNGraph, (s, t); [edata])

test/GNNGraphs/transform.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,24 @@ end
239239
@test edata_new == edatatest
240240
end end
241241

242+
@testset "drop_nodes" begin
243+
if GRAPH_T == :coo
244+
Random.seed!(42)
245+
s = [1, 1, 2, 3]
246+
t = [2, 3, 4, 5]
247+
g = GNNGraph(s, t, graph_type = GRAPH_T)
248+
249+
gnew = drop_nodes(g, Float32(0.5))
250+
@test gnew.num_nodes == 3
251+
252+
gnew = drop_nodes(g, Float32(1.0))
253+
@test gnew.num_nodes == 0
254+
255+
gnew = drop_nodes(g, Float32(0.0))
256+
@test gnew.num_nodes == 5
257+
end
258+
end
259+
242260
@testset "add_nodes" begin if GRAPH_T == :coo
243261
g = rand_graph(6, 4, ndata = rand(2, 6), graph_type = GRAPH_T)
244262
gnew = add_nodes(g, 5, ndata = ones(2, 5))

0 commit comments

Comments
 (0)