Skip to content

Commit ade8659

Browse files
cleanup
1 parent 59c8a7f commit ade8659

File tree

5 files changed

+40
-15
lines changed

5 files changed

+40
-15
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ version = "0.2.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8-
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
98
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
109
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1110
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -14,7 +13,6 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1413
LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
1514
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1615
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
17-
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
1816
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1917
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2018
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -28,7 +26,7 @@ Adapt = "3"
2826
CUDA = "3.3"
2927
ChainRulesCore = "1"
3028
DataStructures = "0.18"
31-
Flux = "0.12"
29+
Flux = "0.12.7"
3230
KrylovKit = "0.5"
3331
LearnBase = "0.4, 0.5"
3432
LightGraphs = "1.3"

src/deprecations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ update_edge(l, e, m) = e
1616
function propagate(l::GNNLayer, g::GNNGraph, aggr, x, e=nothing)
1717
@warn """
1818
Passing a GNNLayer to propagate is deprecated,
19-
you should pass directly the message function.
19+
you should pass the message function directly.
2020
The new signature is `propagate(f, g, aggr; [xi, xj, e])`.
2121
22-
Also the functions `compute_message`, `update_node`,
23-
and `update_edge` have been deprecated. Please
22+
The functions `compute_message`, `update_node`,
23+
and `update_edge` have been deprecated as well. Please
2424
refer to the documentation.
2525
"""
26-
m = apply_edge((a...) -> compute_message(l, a...), g, x, x, e)
26+
m = apply_edges((a...) -> compute_message(l, a...), g, x, x, e)
2727
m̄ = aggregate_neighbors(g, aggr, m)
2828
x = update_node(l, x, m̄)
2929
e = update_edge(l, e, m)

src/utils.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,6 @@ ones_like(x::CUMAT_T, T=eltype(x), sz=size(x)) = CUDA.ones(T, sz)
7474

7575
ofeltype(x, y) = convert(float(eltype(x)), y)
7676

77-
# TODO move to flux. fix for https://github.com/FluxML/Flux.jl/issues/1720
78-
Flux._cpu_array(x::AbstractSparseArray) = Flux.adapt(SparseMatrixCSC, x)
79-
80-
# TODO. FIX THIS HACK. CUDA.jl support to sparse matrices is very bad, convert to dense
81-
# Revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
82-
Flux._gpu_array(x::AbstractSparseArray) = CuMatrix(x)
83-
84-
8577
# Considers the src a zero dimensional object.
8678
# Useful for implementing `StatsBase.counts`, `degree`, etc...
8779
# function NNlib.scatter!(op, dst::AbstractArray, src::Number, idx::AbstractArray)

test/deprecations.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
@testset "deprecations" begin
2+
@testset "propagate" begin
3+
struct GCN{A<:AbstractMatrix, B, F} <: GNNLayer
4+
weight::A
5+
bias::B
6+
σ::F
7+
end
8+
9+
Flux.@functor GCN # allow collecting params, gpu movement, etc...
10+
11+
function GCN(ch::Pair{Int,Int}, σ=identity)
12+
in, out = ch
13+
W = Flux.glorot_uniform(out, in)
14+
b = zeros(Float32, out)
15+
GCN(W, b, σ)
16+
end
17+
18+
GraphNeuralNetworks.compute_message(l::GCN, xi, xj, e) = xj
19+
20+
function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T
21+
x, _ = propagate(l, g, +, x)
22+
return l.σ.(l.weight * x .+ l.bias)
23+
end
24+
25+
function new_forward(l, g, x)
26+
x = propagate(copyxj, g, +, xj=x)
27+
return l.σ.(l.weight * x .+ l.bias)
28+
end
29+
30+
g = GNNGraph(random_regular_graph(10, 4), ndata=randn(3, 10))
31+
l = GCN(3 => 5, tanh)
32+
@test l(g, g.ndata.x) ≈ new_forward(l, g, g.ndata.x)
33+
end
34+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ tests = [
2424
"layers/conv",
2525
"layers/pool",
2626
"examples/node_classification_cora",
27+
"deprecations",
2728
]
2829

2930
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")

0 commit comments

Comments
 (0)