Skip to content

Commit 7bbd997

Browse files
authored
Merge pull request #23 from SciML/nn
Add basic iris NN example
2 parents 7c987f6 + b9468a2 commit 7bbd997

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ version = "1.0.0-DEV"
66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
88
DiffEqGPU = "071ae1c0-96b5-11e9-1965-c90190d839ea"
9+
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
910
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1011
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1112
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
13+
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
1214
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1315

1416
[compat]

examples/neural_network/nn.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using SimpleChains
2+
using IterTools
3+
using MLDatasets
4+
using Random
5+
dataset = MLDatasets.Iris().dataframe
6+
7+
data = Array(dataset)
8+
data = data[shuffle(1:end), :]
9+
10+
function mapstrtoclass(flower)
11+
if string(flower) == "Iris-setosa"
12+
return UInt32(1)
13+
elseif string(flower) == "Iris-versicolor"
14+
return UInt32(2)
15+
elseif string(flower) == "Iris-virginica"
16+
return UInt32(3)
17+
end
18+
end
19+
ytrain = map(mapstrtoclass, data[:, 5])
20+
lenet = SimpleChain(
21+
static(4),
22+
TurboDense{true}(tanh, 20),
23+
TurboDense{true}(identity, 3),
24+
)
25+
lenetloss = SimpleChains.add_loss(lenet, LogitCrossEntropyLoss(ytrain))
26+
27+
p = SimpleChains.init_params(lenet);
28+
xtrain = Float32.(Array(data[:, 1:4]'))
29+
G = SimpleChains.alloc_threaded_grad(lenet);
30+
31+
lenetloss(xtrain, p)
32+
33+
report = let mlpdloss = lenetloss, X=xtrain
34+
p -> begin
35+
let train = mlpdloss(X, p)
36+
@info "Loss:" train
37+
end
38+
end
39+
end
40+
41+
for _ in 1:3
42+
@time SimpleChains.train_unbatched!(
43+
G, p, lenetloss, xtrain, SimpleChains.ADAM(), 5000
44+
);
45+
report(p)
46+
end
47+
48+
p = SimpleChains.init_params(lenet);
49+
50+
lenetloss(xtrain, p)
51+
52+
using Optimization, PSOGPU
53+
54+
lb = -ones(length(p)) .* 10
55+
ub = ones(length(p)) .* 10
56+
prob = OptimizationProblem((u,data) -> lenetloss(data, u), p, xtrain; lb = lb, ub = ub)
57+
58+
n_particles = 1000
59+
60+
sol = solve(prob,
61+
ParallelPSOKernel(n_particles; gpu = false, threaded = true),
62+
maxiters = 1000)

0 commit comments

Comments
 (0)