-
Notifications
You must be signed in to change notification settings - Fork 53
Added EdgeWeightNorm layer #158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,7 @@ export | |
ResGatedGraphConv, | ||
SAGEConv, | ||
GMMConv, | ||
EdgeWeightNorm, | ||
|
||
# layers/pool | ||
GlobalPool, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1181,3 +1181,69 @@ function Base.show(io::IO, l::GMMConv) | |
l.residual==true || print(io, ", residual=", l.residual) | ||
print(io, ")") | ||
end | ||
|
||
@doc raw""" | ||
EdgeWeightNorm(norm_both = true, eps = 0) | ||
|
||
Normalizes positive scalar edge weights on a graph following the form in GCN. | ||
|
||
norm_both = `true` yields the following normalization term: | ||
```math | ||
c_{ji} = (\sqrt{\sum_{k\in\mathcal{N}(j)}e_{jk}}\sqrt{\sum_{k\in\mathcal{N}(i)}e_{ki}}) | ||
``` | ||
norm_both = `false` yields the following normalization term: | ||
```math | ||
c_{ji} = (\sum_{k\in\mathcal{N}(i)}e_{ki}) | ||
``` | ||
where ``e_{ji}`` is the scalar weight on the edge from node j to node i. | ||
|
||
Return value is the normalized weight ``e_{ji} / c_{ji}`` for all edges in vector form. | ||
|
||
# Arguments | ||
|
||
- `norm_both`: The normalizer as specified above. Default is `true`. | ||
- `eps`: Offset value in the denominator. Default is `0`. | ||
|
||
# Examples | ||
|
||
```julia | ||
# create data | ||
g = GNNGraph([1,2,3,4,3,6], [2,3,4,5,1,4]) | ||
g = add_self_loops(g) | ||
|
||
# edge weights | ||
edge_weights = [0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1] | ||
|
||
l = EdgeWeightNorm() | ||
l(g, edge_weights) | ||
``` | ||
""" | ||
struct EdgeWeightNorm <: GNNLayer | ||
norm_both::Bool | ||
eps::Float64 | ||
end | ||
|
||
@functor EdgeWeightNorm | ||
|
||
function EdgeWeightNorm(norm_both::Bool = true, | ||
eps::Float64 = 0) | ||
EdgeWeightNorm(norm_both, eps) | ||
end | ||
|
||
function (l::EdgeWeightNorm)(g::GNNGraph, edge_weight::AbstractVector) | ||
norm_val = Vector{Float64}() | ||
rbSparky marked this conversation as resolved.
Show resolved
Hide resolved
|
||
edge_in, edge_out = edge_index(g) | ||
|
||
dg_in = degree(g; dir = :in, edge_weight) | ||
dg_out = degree(g; dir = :out, edge_weight) | ||
|
||
for iter in 1:length(edge_weight) | ||
if l.norm_both | ||
push!(norm_val, edge_weight[iter] / (sqrt(dg_out[in[iter]] * dg_in[out[iter]]) + l.eps)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these mutating operations are not AD friendly. I didn't think about it carefully but you should probably use apply_edges here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried doing this using aggregate_neighbours, but ran into the same issue of it not being AD friendly, main issue being that I'm currently computing I'm not sure yet how its to be done using apply_edges, but I'll look more into it(will try to understand how its done in pyTorch/DGL) and let you know any updates. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these mutating operations are not AD friendly. I didn't think about it carefully but you should probably use apply_edges here |
||
else | ||
push!(norm_val, edge_weight[iter] / (dg_in[out[iter]] + l.eps)) | ||
end | ||
end | ||
|
||
return norm_val | ||
end |
Uh oh!
There was an error while loading. Please reload this page.