Skip to content

Commit bdc1604

Browse files
drop_nodes(g, p) -> remove_nodes(g, p)
1 parent 43d4ab0 commit bdc1604

File tree

3 files changed

+19
-28
lines changed

3 files changed

+19
-28
lines changed

GNNGraphs/src/GNNGraphs.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ export add_nodes,
8080
perturb_edges,
8181
remove_nodes,
8282
ppr_diffusion,
83-
drop_nodes,
8483
# from MLUtils
8584
batch,
8685
unbatch,

GNNGraphs/src/transform.jl

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -307,35 +307,27 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
307307
end
308308

309309
"""
310-
drop_nodes(g::GNNGraph{<:COO_T}, p)
310+
remove_nodes(g::GNNGraph, p)
311311
312-
Randomly drop nodes (and their associated edges) from a GNNGraph based on a given probability.
313-
Dropping nodes is a technique that can be used for graph data augmentation, refering paper [DropNode](https://arxiv.org/pdf/2008.12578.pdf).
312+
Returns a new graph obtained by dropping nodes from `g` with independent probabilities `p`.
314313
315-
# Arguments
316-
- `g`: The input graph from which nodes (and their associated edges) will be dropped.
317-
- `p`: The probability of dropping each node. Default value is `0.5`.
318-
319-
# Returns
320-
A modified GNNGraph with nodes (and their associated edges) dropped based on the given probability.
314+
# Examples
321315
322-
# Example
323316
```julia
324-
using GraphNeuralNetworks
325-
# Construct a GNNGraph
326-
g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1], num_nodes=3)
327-
# Drop nodes with a probability of 0.5
328-
g_new = drop_node(g, 0.5)
329-
println(g_new)
317+
julia> g = GNNGraph([1, 1, 2, 2, 3, 4], [1, 2, 3, 1, 3, 1])
318+
GNNGraph:
319+
num_nodes: 4
320+
num_edges: 6
321+
322+
julia> g_new = remove_nodes(g, 0.5)
323+
GNNGraph:
324+
num_nodes: 2
325+
num_edges: 2
330326
```
331327
"""
332-
function drop_nodes(g::GNNGraph{<:COO_T}, p = 0.5)
333-
num_nodes = g.num_nodes
334-
nodes_to_remove = filter(_ -> rand() < p, 1:num_nodes)
335-
336-
new_g = remove_nodes(g, nodes_to_remove)
337-
338-
return new_g
328+
function remove_nodes(g::GNNGraph, p::AbstractFloat)
329+
nodes_to_remove = filter(_ -> rand() < p, 1:g.num_nodes)
330+
return remove_nodes(g, nodes_to_remove)
339331
end
340332

341333
"""

GNNGraphs/test/transform.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,20 +247,20 @@ end end
247247
@test edata_new == edatatest
248248
end end
249249

250-
@testset "drop_nodes" begin
250+
@testset "remove_nodes(g, p)" begin
251251
if GRAPH_T == :coo
252252
Random.seed!(42)
253253
s = [1, 1, 2, 3]
254254
t = [2, 3, 4, 5]
255255
g = GNNGraph(s, t, graph_type = GRAPH_T)
256256

257-
gnew = drop_nodes(g, Float32(0.5))
257+
gnew = remove_nodes(g, 0.5)
258258
@test gnew.num_nodes == 3
259259

260-
gnew = drop_nodes(g, Float32(1.0))
260+
gnew = remove_nodes(g, 1.0)
261261
@test gnew.num_nodes == 0
262262

263-
gnew = drop_nodes(g, Float32(0.0))
263+
gnew = remove_nodes(g, 0.0)
264264
@test gnew.num_nodes == 5
265265
end
266266
end

0 commit comments

Comments
 (0)