Commit 0151574

rename message to compute_message
1 parent 95d0c50 commit 0151574

4 files changed: +16 -16 lines


src/layers/conv.jl

Lines changed: 5 additions & 5 deletions
@@ -45,7 +45,7 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
     l.σ.(l.weight * x .+ l.bias)
 end

-message(l::GCNConv, xi, xj, eij) = xj
+compute_message(l::GCNConv, xi, xj, eij) = xj
 update_node(l::GCNConv, m, x) = m

 function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
@@ -176,7 +176,7 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
     GraphConv(W1, W2, b, σ, aggr)
 end

-message(l::GraphConv, x_i, x_j, e_ij) = x_j
+compute_message(l::GraphConv, x_i, x_j, e_ij) = x_j
 update_node(l::GraphConv, m, x) = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)

 function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
@@ -326,7 +326,7 @@ function GatedGraphConv(out_ch::Int, num_layers::Int;
 end


-message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j
+compute_message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j
 update_node(l::GatedGraphConv, m, x) = m

 # remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
@@ -381,7 +381,7 @@ end

 EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)

-message(l::EdgeConv, x_i, x_j, e_ij) = l.nn(vcat(x_i, x_j .- x_i))
+compute_message(l::EdgeConv, x_i, x_j, e_ij) = l.nn(vcat(x_i, x_j .- x_i))

 update_node(l::EdgeConv, m, x) = m

@@ -426,7 +426,7 @@ function GINConv(nn; eps=0f0)
     GINConv(nn, eps)
 end

-message(l::GINConv, x_i, x_j, e_ij) = x_j
+compute_message(l::GINConv, x_i, x_j, e_ij) = x_j
 update_node(l::GINConv, m, x) = l.nn((1 + l.eps) * x + m)

 function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
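A note for downstream code: after this commit, custom layers extend `compute_message` instead of `message`; the other hooks (`update_node`, `update_edge`, `propagate`) are unchanged. A minimal sketch of a custom layer under the new name — the layer `MyConv` and its sizes are invented for illustration, while the `compute_message`/`update_node`/`propagate` pattern mirrors the layers above and the tests below:

```julia
using GraphNeuralNetworks

# Hypothetical layer: linearly transform each neighbor's features,
# then sum them over the neighborhood.
struct MyConv{W<:AbstractMatrix}
    weight::W
end

MyConv(in::Int, out::Int) = MyConv(randn(Float32, out, in))

# This method was `GraphNeuralNetworks.message` before this commit.
GraphNeuralNetworks.compute_message(l::MyConv, x_i, x_j, e_ij) = l.weight * x_j
GraphNeuralNetworks.update_node(l::MyConv, m, x) = m

# Forward pass; `+` is the neighbor aggregation.
(l::MyConv)(g::GNNGraph, x::AbstractMatrix) = GraphNeuralNetworks.propagate(l, g, +, x)
```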

src/msgpass.jl

Lines changed: 6 additions & 6 deletions
@@ -12,7 +12,7 @@ with updated features.
 The computational steps are the following:

 ```julia
-m = compute_batch_message(l, g, x, e) # calls `message`
+m = compute_batch_message(l, g, x, e) # calls `compute_message`
 m̄ = aggregate_neighbors(l, aggr, g, m)
 x′ = update_node(l, m̄, x)
 e′ = update_edge(l, m, e)
@@ -72,7 +72,7 @@ end
 ## Step 1.

 """
-    message(l, x_i, x_j, [e_ij])
+    compute_message(l, x_i, x_j, [e_ij])

 Message function for the message-passing scheme,
 returning the message from node `j` to node `i`.
@@ -81,7 +81,7 @@ from the neighborhood of `i` will later be aggregated
 in order to update (see [`update_node`](@ref)) the features of node `i`.

 The function operates on batches of edges, therefore
-`x_i`, `x_j`, and `e_ij` are tensors whose last dimention
+`x_i`, `x_j`, and `e_ij` are tensors whose last dimension
 is the batch size.

 By default, the function returns `x_j`.
@@ -98,8 +98,8 @@ See also [`update_node`](@ref) and [`propagate`](@ref).
 """
 function message end

-@inline message(l, x_i, x_j, e_ij) = message(l, x_i, x_j)
-@inline message(l, x_i, x_j) = x_j
+@inline compute_message(l, x_i, x_j, e_ij) = compute_message(l, x_i, x_j)
+@inline compute_message(l, x_i, x_j) = x_j

 _gather(x, i) = NNlib.gather(x, i)
 _gather(x::Nothing, i) = nothing
@@ -108,7 +108,7 @@ function compute_batch_message(l, g, x, e)
     s, t = edge_index(g)
     xi = _gather(x, t)
     xj = _gather(x, s)
-    m = message(l, xi, xj, e)
+    m = compute_message(l, xi, xj, e)
     return m
 end
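Worth noting in the fallbacks above: the three-argument `compute_message(l, x_i, x_j, e_ij)` forwards to the two-argument form, so a layer that ignores edge features only needs to define the shorter method. A small sketch (the layer name `DiffLayer` is made up):

```julia
struct DiffLayer end

# The generic three-argument fallback forwards here, so edge features
# can be ignored without defining an extra method.
GraphNeuralNetworks.compute_message(::DiffLayer, x_i, x_j) = x_j .- x_i
```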

test/cuda/msgpass.jl

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 @functor NewCudaLayer

 (l::NewCudaLayer)(g, X) = GraphNeuralNetworks.propagate(l, g, +, X)
-GraphNeuralNetworks.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
+GraphNeuralNetworks.compute_message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
 GraphNeuralNetworks.update_node(::NewCudaLayer, m, x) = m

 X = rand(T, in_channel, N) |> gpu

test/msgpass.jl

Lines changed: 4 additions & 4 deletions
@@ -51,7 +51,7 @@
 @testset "custom message and neighbor aggregation" begin
     struct NewLayer3{G} end

-    GraphNeuralNetworks.message(l::NewLayer3{GRAPH_T}, xi, xj, e) = ones(T, out_channel, size(e,2))
+    GraphNeuralNetworks.compute_message(l::NewLayer3{GRAPH_T}, xi, xj, e) = ones(T, out_channel, size(e,2))
     (l::NewLayer3{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, +)


@@ -70,7 +70,7 @@
 struct NewLayer4{G} end

 GraphNeuralNetworks.update_edge(l::NewLayer4{GRAPH_T}, m, e) = m
-GraphNeuralNetworks.message(l::NewLayer4{GRAPH_T}, xi, xj, e) = ones(T, out_channel, size(e,2))
+GraphNeuralNetworks.compute_message(l::NewLayer4{GRAPH_T}, xi, xj, e) = ones(T, out_channel, size(e,2))
 (l::NewLayer4{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, +)

 l = NewLayer4{GRAPH_T}()
@@ -89,7 +89,7 @@

 GraphNeuralNetworks.update_node(l::NewLayer5{GRAPH_T}, m̄, xi) = rand(T, 2*out_channel, size(xi, 2))
 GraphNeuralNetworks.update_edge(l::NewLayer5{GRAPH_T}, m, e) = m
-GraphNeuralNetworks.message(l::NewLayer5{GRAPH_T}, xi, xj, e) = ones(T, out_channel, size(e,2))
+GraphNeuralNetworks.compute_message(l::NewLayer5{GRAPH_T}, xi, xj, e) = ones(T, out_channel, size(e,2))
 (l::NewLayer5{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, +)

 l = NewLayer5{GRAPH_T}()
@@ -110,7 +110,7 @@

 NewLayerW(in, out) = NewLayerW{GRAPH_T}(randn(T, out, in))

-GraphNeuralNetworks.message(l::NewLayerW{GRAPH_T}, x_i, x_j, e_ij) = l.weight * x_j
+GraphNeuralNetworks.compute_message(l::NewLayerW{GRAPH_T}, x_i, x_j, e_ij) = l.weight * x_j
 GraphNeuralNetworks.update_node(l::NewLayerW{GRAPH_T}, m, x) = l.weight * x + m

 l = NewLayerW(in_channel, out_channel)
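As a rough end-to-end check of the renamed interface, the hypothetical `MyConv` sketch from the conv.jl note above could be exercised like this (the `GNNGraph(s, t)` COO constructor and the `num_nodes` field are assumptions about this version's API):

```julia
s, t = [1, 2, 3], [2, 3, 1]          # edges of a directed 3-cycle
g = GNNGraph(s, t)                   # assumed COO constructor
x = rand(Float32, 4, g.num_nodes)    # 4 features per node
l = MyConv(4, 8)
y = l(g, x)                          # messages summed over each neighborhood
@assert size(y) == (8, g.num_nodes)
```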
