Skip to content

Commit 3bcafbe

Browse files
Added Personalized PageRank Diffusion [ppr_diffusion function] (#427)
* add ppr diffusion * add ppr diffusion * add function to GNNGraphs.jl * :coo * Update transform.jl * try * Made function non-mutating uses SparseArrays * Update src/GNNGraphs/transform.jl rename args Co-authored-by: Carlo Lucibello <[email protected]> * Update test/GNNGraphs/transform.jl clean code Co-authored-by: Carlo Lucibello <[email protected]> * Update test/GNNGraphs/transform.jl remove unneeded line Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/transform.jl args fix Co-authored-by: Carlo Lucibello <[email protected]> * Update src/GNNGraphs/transform.jl rename var Co-authored-by: Carlo Lucibello <[email protected]> * empty weights * indent * fixes * Update test/GNNGraphs/transform.jl --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent bcce0cf commit 3bcafbe

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ export add_nodes,
7979
to_unidirected,
8080
random_walk_pe,
8181
remove_nodes,
82+
ppr_diffusion,
8283
drop_nodes,
8384
# from Flux
8485
batch,

src/GNNGraphs/transform.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,3 +1168,49 @@ ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci
11681168
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
11691169
@non_differentiable dense_zeros_like(x...)
11701170

1171+
"""
1172+
ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph
1173+
1174+
Calculates the Personalized PageRank (PPR) diffusion based on the edge weight matrix of a GNNGraph and updates the graph with new edge weights derived from the PPR matrix.
1175+
References paper: [The pagerank citation ranking: Bringing order to the web](http://ilpubs.stanford.edu:8090/422)
1176+
1177+
1178+
The function performs the following steps:
1179+
1. Constructs a modified adjacency matrix `A` using the graph's edge weights, where `A` is adjusted by `(α - 1) * A + I`, with `α` being the damping factor (`alpha_f32`) and `I` the identity matrix.
1180+
2. Normalizes `A` to ensure each column sums to 1, representing transition probabilities.
1181+
3. Applies the PPR formula `α * (I + (α - 1) * A)^-1` to compute the diffusion matrix.
1182+
4. Updates the original edge weights of the graph based on the PPR diffusion matrix, assigning new weights for each edge from the PPR matrix.
1183+
1184+
# Arguments
1185+
- `g::GNNGraph`: The input graph for which PPR diffusion is to be calculated. It should have edge weights available.
1186+
- `alpha_f32::Float32`: The damping factor used in PPR calculation, controlling the teleport probability in the random walk. Defaults to `0.85f0`.
1187+
1188+
# Returns
1189+
- A new `GNNGraph` instance with the same structure as `g` but with updated edge weights according to the PPR diffusion calculation.
1190+
"""
1191+
function ppr_diffusion(g::GNNGraph{<:COO_T}; alpha = 0.85f0)
1192+
s, t = edge_index(g)
1193+
w = get_edge_weight(g)
1194+
if isnothing(w)
1195+
w = ones(Float32, g.num_edges)
1196+
end
1197+
1198+
N = g.num_nodes
1199+
1200+
initial_A = sparse(t, s, w, N, N)
1201+
scaled_A = (Float32(alpha) - 1) * initial_A
1202+
1203+
I_sparse = sparse(Diagonal(ones(Float32, N)))
1204+
A_sparse = I_sparse + scaled_A
1205+
1206+
A_dense = Matrix(A_sparse)
1207+
1208+
PPR = alpha * inv(A_dense)
1209+
1210+
new_w = [PPR[dst, src] for (src, dst) in zip(s, t)]
1211+
1212+
return GNNGraph((s, t, new_w),
1213+
g.num_nodes, length(s), g.num_graphs,
1214+
g.graph_indicator,
1215+
g.ndata, g.edata, g.gdata)
1216+
end

test/GNNGraphs/transform.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,4 +595,24 @@ end
595595

596596
@test g.graph[(:A, :to1, :A)][3] == vcat([2, 2, 2], fill(1, n))
597597
end
598+
end
599+
600+
@testset "ppr_diffusion" begin
601+
if GRAPH_T == :coo
602+
s = [1, 1, 2, 3]
603+
t = [2, 3, 4, 5]
604+
eweights = [0.1, 0.2, 0.3, 0.4]
605+
606+
g = GNNGraph(s, t, eweights)
607+
608+
g_new = ppr_diffusion(g)
609+
w_new = get_edge_weight(g_new)
610+
611+
check_ew = Float32[0.012749999
612+
0.025499998
613+
0.038249996
614+
0.050999995]
615+
616+
@test w_new check_ew
617+
end
598618
end

0 commit comments

Comments
 (0)