1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -61,6 +61,7 @@ export
ResGatedGraphConv,
SAGEConv,
GMMConv,
EdgeWeightNorm,

# layers/pool
GlobalPool,
66 changes: 66 additions & 0 deletions src/layers/conv.jl
@@ -1181,3 +1181,69 @@ function Base.show(io::IO, l::GMMConv)
l.residual==true || print(io, ", residual=", l.residual)
print(io, ")")
end

@doc raw"""
    EdgeWeightNorm(norm_both = true, eps = 0.0)

Normalizes positive scalar edge weights of a graph, following the form used in GCN.

`norm_both = true` yields the following normalization term:
```math
c_{ji} = \sqrt{\sum_{k\in\mathcal{N}(j)}e_{jk}} \, \sqrt{\sum_{k\in\mathcal{N}(i)}e_{ki}}
```
`norm_both = false` yields the following normalization term:
```math
c_{ji} = \sum_{k\in\mathcal{N}(i)}e_{ki}
```
where ``e_{ji}`` is the scalar weight on the edge from node ``j`` to node ``i``.

The return value is the vector of normalized weights ``e_{ji} / c_{ji}``, one entry per edge.
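
For instance (numbers chosen only for illustration), if an edge with weight ``e_{ji} = 0.5`` leaves a node ``j`` whose outgoing weights sum to ``2.0`` and enters a node ``i`` whose incoming weights sum to ``1.5``, the `norm_both = true` form returns ``0.5 / \sqrt{2.0 \cdot 1.5} \approx 0.29`` for that edge.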

# Arguments

- `norm_both`: If `true`, normalize by both the weighted out-degree of the source node and the weighted in-degree of the destination node, as above; if `false`, normalize only by the weighted in-degree. Default is `true`.
- `eps`: Offset value in the denominator. Default is `0`.

# Examples

```julia
# create data
g = GNNGraph([1,2,3,4,3,6], [2,3,4,5,1,4])
g = add_self_loops(g)

# edge weights: one per edge (6 original edges followed by the 6 self-loops added above)
edge_weights = [0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1]

l = EdgeWeightNorm()
l(g, edge_weights)
```
"""
struct EdgeWeightNorm <: GNNLayer
    norm_both::Bool
    eps::Float64
end

@functor EdgeWeightNorm

function EdgeWeightNorm(norm_both::Bool = true, eps::Real = 0.0)
    EdgeWeightNorm(norm_both, Float64(eps))
end

function (l::EdgeWeightNorm)(g::GNNGraph, edge_weight::AbstractVector)
    norm_val = Vector{Float64}()
    edge_in, edge_out = edge_index(g)

    # weighted degree sums: dg_out[j] = ∑ₖ e_{jk}, dg_in[i] = ∑ₖ e_{ki}
    dg_in = degree(g; dir = :in, edge_weight)
    dg_out = degree(g; dir = :out, edge_weight)

    for iter in eachindex(edge_weight)
        if l.norm_both
            push!(norm_val, edge_weight[iter] / (sqrt(dg_out[edge_in[iter]] * dg_in[edge_out[iter]]) + l.eps))
Member:
these mutating operations are not AD friendly. I didn't think about it carefully but you should probably use apply_edges here

Contributor Author:
I tried doing this using `aggregate_neighbours`, but ran into the same issue of it not being AD friendly; the main issue is that I'm currently computing ∑e_jk and ∑e_ki individually for each e_ji, which I guess should be done with some sort of matrix multiplication.

I'm not sure yet how it's to be done using `apply_edges`, but I'll look more into it (I'll try to understand how it's done in PyTorch/DGL) and let you know of any updates.
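
For illustration only (not code from this PR): a minimal non-mutating sketch of the same computation, assuming `edge_index` and `degree` behave as used above. The free-standing name `edge_weight_norm` is hypothetical, and `apply_edges` as suggested would likely still be the more idiomatic route.

```julia
# Hypothetical sketch, not part of the PR: same math as the loop above, but it
# gathers the per-node degree sums for each edge and broadcasts, instead of
# mutating a result vector, which should be friendlier to Zygote.
function edge_weight_norm(l::EdgeWeightNorm, g::GNNGraph, edge_weight::AbstractVector)
    s, t = edge_index(g)                        # source and target node of each edge
    dg_in = degree(g; dir = :in, edge_weight)   # ∑ₖ e_{ki} for every node i
    dg_out = degree(g; dir = :out, edge_weight) # ∑ₖ e_{jk} for every node j
    if l.norm_both
        return edge_weight ./ (sqrt.(dg_out[s] .* dg_in[t]) .+ l.eps)
    else
        return edge_weight ./ (dg_in[t] .+ l.eps)
    end
end
```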


        else
            push!(norm_val, edge_weight[iter] / (dg_in[edge_out[iter]] + l.eps))
        end
    end

    return norm_val
end