
Commit 57bcddc

modernize examples (#295)
1 parent 376586b commit 57bcddc

File tree

7 files changed: +2761 −14 lines changed


examples/Project.toml

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@ NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 [compat]
 DiffEqFlux = "1.45"
 Flux = "0.13"
-GraphNeuralNetworks = "0.5"
+GraphNeuralNetworks = "0.6"
 Graphs = "1"
-MLDatasets = "0.6, 0.7"
-julia = "1.7"
+MLDatasets = "0.7"
+julia = "1.9"
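
The [compat] bump pins the examples to GraphNeuralNetworks 0.6, MLDatasets 0.7, and Julia 1.9. As a minimal sketch (not part of this commit, and assuming the examples keep their own `examples/Project.toml` environment), the updated environment would be re-resolved with the standard Pkg calls:

using Pkg

# Activate the examples environment and install the versions allowed
# by the new [compat] bounds (assumed workflow, not from the commit).
Pkg.activate("examples")
Pkg.instantiate()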

examples/graph_classification_tudataset.jl

Lines changed: 3 additions & 4 deletions
@@ -81,8 +81,7 @@ function train(; kws...)
                       GlobalPool(mean),
                       Dense(nhidden, 1)) |> device

-    ps = Flux.params(model)
-    opt = Adam(args.η)
+    opt = Flux.setup(Adam(args.η), model)

     # LOGGING FUNCTION

@@ -98,11 +97,11 @@ function train(; kws...)
     for epoch in 1:(args.epochs)
         for (g, y) in train_loader
             g, y = (g, y) |> device
-            gs = Flux.gradient(ps) do
+            grads = Flux.gradient(model) do model
                 ŷ = model(g, g.ndata.x) |> vec
                 logitbinarycrossentropy(ŷ, y)
             end
-            Flux.Optimise.update!(opt, ps, gs)
+            Flux.update!(opt, model, grads[1])
         end
         epoch % args.infotime == 0 && report(epoch)
     end
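
Both hunks migrate from the implicit `Flux.params` style to Flux's explicit optimiser-state API (`Flux.setup` / `Flux.update!`). A minimal, self-contained sketch of that new-style training step, with a toy model and random data standing in for the GNN and the DataLoader, looks roughly like:

using Flux

toy_model = Dense(4 => 1)                    # placeholder for the real GNN model
opt = Flux.setup(Adam(1e-3), toy_model)      # optimiser state tied to the model
x, y = randn(Float32, 4, 8), randn(Float32, 1, 8)

grads = Flux.gradient(toy_model) do m        # do-block form, as in the hunk above
    Flux.mse(m(x), y)                        # any scalar loss of the model
end
Flux.update!(opt, toy_model, grads[1])       # grads[1] pairs with the first gradient argument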

examples/link_prediction_pubmed.jl

Lines changed: 6 additions & 7 deletions
@@ -77,12 +77,11 @@ function train(; kws...)

     pred = DotPredictor()

-    ps = Flux.params(model)
-    opt = Adam(args.η)
+    opt = Flux.setup(Adam(args.η), model)

     ### LOSS FUNCTION ############

-    function loss(pos_g, neg_g = nothing; with_accuracy = false)
+    function loss(model, pos_g, neg_g = nothing; with_accuracy = false)
         h = model(X)
         if neg_g === nothing
             # We sample a negative graph at each training step

@@ -103,16 +102,16 @@ function train(; kws...)

     ### LOGGING FUNCTION
     function report(epoch)
-        train_loss, train_acc = loss(train_pos_g, with_accuracy = true)
-        test_loss, test_acc = loss(test_pos_g, test_neg_g, with_accuracy = true)
+        train_loss, train_acc = loss(model, train_pos_g, with_accuracy = true)
+        test_loss, test_acc = loss(model, test_pos_g, test_neg_g, with_accuracy = true)
         println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))")
     end

     ### TRAINING
     report(0)
     for epoch in 1:(args.epochs)
-        gs = Flux.gradient(() -> loss(train_pos_g), ps)
-        Flux.Optimise.update!(opt, ps, gs)
+        grads = Flux.gradient(model -> loss(model, train_pos_g), model)
+        Flux.update!(opt, model, grads[1])
         epoch % args.infotime == 0 && report(epoch)
     end
 end
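
In this example the loss now receives the model as its first argument, so the gradient is taken with an explicit closure rather than a do-block. A short sketch of that form, with a placeholder binary-classification loss standing in for the example's link-prediction loss:

using Flux

toy_model = Dense(4 => 1)                    # placeholder model
opt = Flux.setup(Adam(1e-3), toy_model)
X = randn(Float32, 4, 16)
y = Float32.(rand(Bool, 16))

loss(m, x, y) = Flux.logitbinarycrossentropy(vec(m(x)), y)

grads = Flux.gradient(m -> loss(m, X, y), toy_model)
Flux.update!(opt, toy_model, grads[1])
# The same loss can also be called directly for logging, e.g. loss(toy_model, X, y),
# mirroring the report(epoch) change above.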

examples/neurosat.jl

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
using GraphNeuralNetworks
using Flux
using Random, Statistics, LinearAlgebra
using SparseArrays

"""
A type representing a conjunctive normal form.
"""
struct CNF
    N::Int # num variables
    M::Int # num factors
    clauses::Vector{Vector{Int}}
end

function CNF(clauses::Vector{Vector{Int}})
    M = length(clauses)
    N = maximum(maximum(abs.(c)) for c in clauses)
    return CNF(N, M, clauses)
end

"""
    randomcnf(; N=100, k=3, α=0.1, seed=-1, planted = Vector{Vector{Int}}())

Generates a random instance of the k-SAT problem, with `N` variables and `αN` clauses.
Any configuration in `planted` is guaranteed to be a solution of the problem.
"""
function randomcnf(; N::Int = 100, k::Int = 3, α::Float64 = 0.1, seed::Int = -1,
                   planted = Vector{Vector{Int}}())
    seed > 0 && Random.seed!(seed)
    M = round(Int, N * α)
    clauses = Vector{Vector{Int}}()
    for p in planted
        @assert length(p) == N "Wrong size for planted configurations ($N != $(length(p)))"
    end
    for a = 1:M
        while true
            c = rand(1:N, k)
            length(union(c)) != k && continue
            c = c .* rand([-1, 1], k)

            # reject the clause if it does not satisfy the planted solutions
            sat = Bool[any(i -> i > 0, sol[abs.(c)] .* c) for sol in planted]
            !all(sat) && continue

            push!(clauses, c)
            break
        end
    end
    return CNF(N, M, clauses)
end

function to_edge_index(cnf::CNF)
    N = cnf.N
    srcV, dstF = Vector{Int}(), Vector{Int}()
    srcF, dstV = Vector{Int}(), Vector{Int}()
    for (a, c) in enumerate(cnf.clauses)
        for v in c
            negated = v < 0
            push!(srcV, abs(v) + N * negated)
            push!(dstF, a)
            push!(srcF, a)
            push!(dstV, abs(v) + N * negated)
        end
    end
    return srcV, dstF, srcF, dstV
end

function to_adjacency_matrix(cnf::CNF)
    M, N = cnf.M, cnf.N
    A = spzeros(Int, M, 2 * N)
    for (a, c) in enumerate(cnf.clauses)
        for v in c
            negated = v < 0
            A[a, abs(v) + N * negated] = 1
        end
    end
    return A
end

function flip_literals(X::AbstractMatrix)
    n = size(X, 2) ÷ 2
    return hcat(X[:, n+1:2n], X[:, 1:n])
end

## Layer
struct NeuroSAT
    Xv0
    Xf0
    MLPv
    MLPf
    MLPout
    LSTMv
    LSTMf
end

# A # rectangular adjacency matrix
Flux.@functor NeuroSAT

# Optimisers.trainable(m::NeuroSAT) = (; m.MLPv, m.MLPf, m.MLPout, m.LSTMv, m.LSTMf)

function NeuroSAT(D::Int)
    Xv0 = randn(Float32, D)
    Xf0 = randn(Float32, D)
    MLPv = Chain(Dense(D => 4D, relu), Dense(4D => D))
    MLPf = Chain(Dense(D => 4D, relu), Dense(4D => D))
    MLPout = Chain(Dense(D => 4D, relu), Dense(4D => 1))
    LSTMv = LSTM(2D => D)
    LSTMf = LSTM(D => D)
    return NeuroSAT(Xv0, Xf0, MLPv, MLPf, MLPout, LSTMv, LSTMf)
end

function (m::NeuroSAT)(A::AbstractArray, Tmax)
    Xv = repeat(m.Xv0, 1, size(A, 2))
    # Xf = repeat(m.Xf0, 1, size(A, 1))

    for t = 1:Tmax
        Xv = m.MLPv(Xv)
        Mf = Xv * A'
        Xf = m.MLPf(m.LSTMf(Mf))
        Mv = Xf * A
        Xv = m.LSTMv(vcat(Mv, flip_literals(Xv)))
    end
    return mean(m.MLPout(Xv))
end

# function Base.show(io, m::NeuroSAT)
#     D = size(m.Xv0, 1)
#     print(io, "NeuroSAT($(D))")
# end

N = 100
cnf = randomcnf(; N, k = 3, α = 1.5, seed = -1)
M = cnf.M
D = 32 # 128 in the paper
Xv = randn(Float32, D, 2 * N)
Xf = randn(Float32, D, M)

srcV, dstF, srcF, dstV = to_edge_index(cnf)
A = to_adjacency_matrix(cnf)

model = NeuroSAT(D)

m_vtof = GNNGraphs._gather(Xv, srcV)
m_ftov = GNNGraphs._gather(Xf, srcF)
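
The script stops after gathering the raw variable-to-clause and clause-to-variable messages. Below is a minimal end-to-end sketch (not in the commit) showing how the pieces above fit together on a tiny hand-written formula; reading the pooled logit's sign as a SAT/UNSAT guess is purely illustrative:

# Illustrative usage, not part of the committed file.
# (x1 ∨ ¬x2 ∨ x3) ∧ (¬x1 ∨ x2) ∧ (x2 ∨ x3), clauses as vectors of signed variable indices
cnf_small = CNF([[1, -2, 3], [-1, 2], [2, 3]])
A_small = to_adjacency_matrix(cnf_small)      # 3×6 incidence matrix: columns 1:3 positive, 4:6 negated literals
model_small = NeuroSAT(16)
score = model_small(Float32.(A_small), 10)    # 10 rounds of message passing
println("pooled logit: ", score, "  (illustrative SAT guess: ", score > 0, ")")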
