Skip to content

Commit e2623eb

Browse files
Added perturb_edges function (#423)
* add edge perturbation * add to gnngraphs * Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> * loop fix * Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello <[email protected]> * Update transform.jl * Update transform.jl * gpu compat * include package * Update test/GNNGraphs/transform.jl --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 0f8e13c commit e2623eb

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import KrylovKit
1414
using ChainRulesCore
1515
using LinearAlgebra, Random, Statistics
1616
import MLUtils
17-
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk
17+
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, rand_like
1818
import Functors
1919

2020
include("chainrules.jl") # hacks for differentiability
@@ -78,6 +78,7 @@ export add_nodes,
7878
to_bidirected,
7979
to_unidirected,
8080
random_walk_pe,
81+
perturb_edges,
8182
remove_nodes,
8283
ppr_diffusion,
8384
drop_nodes,

src/GNNGraphs/transform.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,72 @@ function add_edges(g::GNNHeteroGraph{<:COO_T},
502502
ntypes, etypes)
503503
end
504504

"""
    perturb_edges(g::GNNGraph, perturb_ratio; rng = Random.default_rng())
    perturb_edges(rng, g::GNNGraph, perturb_ratio)

Return a graph obtained from `g` by adding random edges.

The number of added edges is `ceil(Int, g.num_edges * perturb_ratio)`. The endpoints of
the new edges are drawn at random among the nodes of `g`, and self-loops are never
generated. The new edges are not checked against the existing ones, so duplicates of
existing edges may be created. No edge data (`edata`) or weights (`w`) are assigned to the
new edges.

If no edge has to be added (e.g. `perturb_ratio == 0`), `g` itself is returned.

# Arguments

- `g::GNNGraph`: The graph to be perturbed (COO representation).
- `perturb_ratio`: The number of new edges to add, as a fraction of the current number of
  edges. It must lie in `[0, 1]`. For example, a `perturb_ratio` of 0.1 means that 10%
  (rounded up) of the current number of edges is added as new random edges.
- `rng`: The random number generator used to sample the new edges, given either as a
  keyword or as the first positional argument. Pass a seeded generator, e.g.
  `MersenneTwister(42)`, to obtain reproducible perturbations. The generator is advanced
  but never re-seeded.

# Examples

```julia
julia> g = GNNGraph((s, t, w))
GNNGraph:
  num_nodes: 4
  num_edges: 5

julia> perturbed_g = perturb_edges(g, 0.2)
GNNGraph:
  num_nodes: 4
  num_edges: 6 # One new edge added, as 0.2 of 5 is 1.

julia> perturbed_g = perturb_edges(g, 0.5, rng = MersenneTwister(42))
GNNGraph:
  num_nodes: 4
  num_edges: 8 # Three new edges added, as 0.5 of 5 is 2.5, which is rounded up to 3.
```
"""
function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::Real;
                       rng::AbstractRNG = Random.default_rng())
    @assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1"

    # NOTE: `rng` is deliberately NOT re-seeded here. Calling `Random.seed!(rng)` would
    # replace the caller's seed with fresh entropy and make the result irreproducible.
    num_edges_to_add = ceil(Int, g.num_edges * perturb_ratio)
    if num_edges_to_add == 0
        return g
    end

    num_nodes = g.num_nodes
    @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges"

    snew, tnew = _perturb_sample_edges(rng, num_nodes, num_edges_to_add)

    # Candidates that turned out to be self-loops were discarded: keep drawing exactly
    # the number of edges still missing until the target is reached.
    while length(snew) < num_edges_to_add
        num_missing = num_edges_to_add - length(snew)
        smore, tmore = _perturb_sample_edges(rng, num_nodes, num_missing)
        snew = [snew; smore]
        tnew = [tnew; tmore]
    end

    return add_edges(g, (snew, tnew, nothing))
end

# Positional-`rng` form, following the usual `f(rng, args...)` convention.
function perturb_edges(rng::AbstractRNG, g::GNNGraph{<:COO_T}, perturb_ratio::Real)
    return perturb_edges(g, perturb_ratio; rng = rng)
end

# Draw `n` candidate edges between random nodes and drop the self-loops.
# Returns the pair `(s, t)` of surviving endpoints, so possibly fewer than `n` edges.
function _perturb_sample_edges(rng::AbstractRNG, num_nodes::Integer, n::Integer)
    s = _perturb_sample_nodes(rng, num_nodes, n)
    t = _perturb_sample_nodes(rng, num_nodes, n)
    keep = s .!= t
    return s[keep], t[keep]
end

# Draw `n` node ids in `1:num_nodes`.
function _perturb_sample_nodes(rng::AbstractRNG, num_nodes::Integer, n::Integer)
    r = rand_like(rng, ones(num_nodes), Float32, n)
    # The samples lie in [0, 1): a sample that is exactly 0 would be mapped to the
    # invalid node id 0 by `ceil`, hence the lower bound of 1.
    return max.(ceil.(Int, r .* num_nodes), 1)
end
505571

506572

507573
### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). make it mutable

test/GNNGraphs/transform.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,14 @@ end
177177
end
178178
end
179179

@testset "perturb_edges" begin if GRAPH_T == :coo
    s, t = [1, 2, 3, 4, 5], [2, 3, 4, 5, 1]
    g = GNNGraph((s, t))
    rng = MersenneTwister(42)
    g_per = perturb_edges(g, 0.5, rng = rng)
    # ceil(5 * 0.5) == 3 random edges are appended to the 5 existing ones.
    @test g_per.num_edges == 8
    # Perturbing only adds edges between existing nodes.
    @test g_per.num_nodes == g.num_nodes
    # With a zero ratio there is nothing to add and the input graph is returned as is.
    @test perturb_edges(g, 0.0, rng = rng) === g
end end
187+
180188
@testset "remove_nodes" begin if GRAPH_T == :coo
181189
#single node
182190
s = [1, 1, 2, 3]

0 commit comments

Comments
 (0)