Skip to content

Commit ec9ab1b

Browse files
authored
Merge pull request #907 from RobAltena/master
Update examples
2 parents 388fda0 + 911e5f2 commit ec9ab1b

File tree

76 files changed

+509
-637
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

76 files changed

+509
-637
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -61,3 +61,4 @@ Word2vec-index/
6161
!tutorials/*.json
6262
*end.model
6363

64+
arbiterExample/

dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/BasicHyperparameterOptimizationExample.java

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -42,11 +42,11 @@
4242
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
4343
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener;
4444
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
45-
import org.deeplearning4j.eval.Evaluation;
4645
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
4746
import org.deeplearning4j.nn.weights.WeightInit;
4847
import org.deeplearning4j.ui.api.UIServer;
4948
import org.deeplearning4j.ui.storage.FileStatsStorage;
49+
import org.nd4j.evaluation.classification.Evaluation.Metric;
5050
import org.nd4j.linalg.activations.Activation;
5151
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
5252
import org.nd4j.linalg.lossfunctions.LossFunctions;
@@ -114,14 +114,16 @@ public static void main(String[] args) throws Exception {
114114
// This will result in examples being saved to arbiterExample/0/, arbiterExample/1/, arbiterExample/2/, ...
115115
String baseSaveDirectory = "arbiterExample/";
116116
File f = new File(baseSaveDirectory);
117-
if (f.exists()) f.delete();
117+
if (f.exists()) //noinspection ResultOfMethodCallIgnored
118+
f.delete();
119+
//noinspection ResultOfMethodCallIgnored
118120
f.mkdir();
119121
ResultSaver modelSaver = new FileModelSaver(baseSaveDirectory);
120122

121123
// (d) What are we actually trying to optimize?
122124
// In this example, let's use classification accuracy on the test set
123125
// See also ScoreFunctions.testSetF1(), ScoreFunctions.testSetRegression(regressionValue) etc
124-
ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.ACCURACY);
126+
ScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.ACCURACY);
125127

126128

127129
// (e) When should we stop searching? Specify this with termination conditions

dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/BaseGeneticHyperparameterOptimizationExample.java

Lines changed: 28 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -25,9 +25,10 @@
2525
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
2626
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
2727
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction;
28-
import org.deeplearning4j.eval.Evaluation;
2928
import org.deeplearning4j.nn.graph.ComputationGraph;
29+
import org.nd4j.evaluation.classification.Evaluation.Metric;
3030

31+
import java.io.IOException;
3132
import java.util.List;
3233

3334
/**
@@ -38,23 +39,13 @@
3839
*
3940
* @author Alexandre Boulanger
4041
*/
41-
4242
public class BaseGeneticHyperparameterOptimizationExample {
4343

44-
public static void main(String[] args) throws Exception {
45-
46-
ComputationGraphSpace cgs = GeneticSearchExampleConfiguration.GetGraphConfiguration();
47-
48-
EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.F1);
49-
50-
// This is where we create the GeneticSearchCandidateGenerator with its default behavior:
51-
// - a population that fits 30 candidates and is culled back to 20 when it overflows
52-
// - new candidates are generated with a probability of 85% of being the result of breeding (a k-point crossover with 1 to 4 points)
53-
// - the new candidate have a probability of 0.5% of sustaining a random mutation on one of its genes.
54-
GeneticSearchCandidateGenerator candidateGenerator = new GeneticSearchCandidateGenerator.Builder(cgs, scoreFunction).build();
55-
56-
// Let's have a listener to print the population size after each evaluation.
57-
PopulationModel populationModel = candidateGenerator.getPopulationModel();
44+
/**
45+
* Common code used by two Arbiter examples.
46+
*/
47+
public static void run(PopulationModel populationModel, GeneticSearchCandidateGenerator candidateGenerator,
48+
EvaluationScoreFunction scoreFunction) throws IOException {
5849
populationModel.addListener(new ExamplePopulationListener());
5950

6051
IOptimizationRunner runner = GeneticSearchExampleConfiguration.BuildRunner(candidateGenerator, scoreFunction);
@@ -80,13 +71,32 @@ public static void main(String[] args) throws Exception {
8071
System.out.println(bestModel.getConfiguration().toJson());
8172
}
8273

74+
public static void main(String[] args) throws Exception {
75+
76+
ComputationGraphSpace cgs = GeneticSearchExampleConfiguration.GetGraphConfiguration();
77+
78+
EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.F1);
79+
80+
// This is where we create the GeneticSearchCandidateGenerator with its default behavior:
81+
// - a population that fits 30 candidates and is culled back to 20 when it overflows
82+
// - new candidates are generated with a probability of 85% of being the result of breeding (a k-point crossover with 1 to 4 points)
83+
// - the new candidate have a probability of 0.5% of sustaining a random mutation on one of its genes.
84+
GeneticSearchCandidateGenerator candidateGenerator = new GeneticSearchCandidateGenerator.Builder(cgs, scoreFunction).build();
85+
86+
// Let's have a listener to print the population size after each evaluation.
87+
PopulationModel populationModel = candidateGenerator.getPopulationModel();
88+
populationModel.addListener(new ExamplePopulationListener());
89+
run(populationModel, candidateGenerator, scoreFunction);
90+
}
91+
8392
public static class ExamplePopulationListener implements PopulationListener {
8493

94+
@SuppressWarnings("OptionalGetWithoutIsPresent")
8595
@Override
8696
public void onChanged(List<Chromosome> population) {
8797
double best = population.get(0).getFitness();
8898
double average = population.stream()
89-
.mapToDouble(c -> c.getFitness())
99+
.mapToDouble(Chromosome::getFitness)
90100
.average()
91101
.getAsDouble();
92102
System.out.println(String.format("\nPopulation size is %1$s, best score is %2$s, average score is %3$s", population.size(), best, average));

dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/CustomGeneticHyperparameterOptimizationExample.java

Lines changed: 6 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -20,25 +20,19 @@
2020
import org.apache.commons.math3.random.RandomGenerator;
2121
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
2222
import org.deeplearning4j.arbiter.ComputationGraphSpace;
23-
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
24-
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
2523
import org.deeplearning4j.arbiter.optimize.generator.GeneticSearchCandidateGenerator;
26-
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
2724
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
2825
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
2926
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
3027
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
3128
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
32-
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationListener;
3329
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
3430
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator;
3531
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
36-
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
3732
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction;
38-
import org.deeplearning4j.eval.Evaluation;
39-
import org.deeplearning4j.nn.graph.ComputationGraph;
33+
import org.nd4j.evaluation.classification.Evaluation.Metric;
4034

41-
import java.util.List;
35+
import static org.deeplearning4j.examples.arbiter.genetic.BaseGeneticHyperparameterOptimizationExample.run;
4236

4337
/**
4438
* In this hyperparameter optimization example, we change the default behavior of the genetic candidate generator.
@@ -55,7 +49,7 @@ public static void main(String[] args) throws Exception {
5549

5650
ComputationGraphSpace cgs = GeneticSearchExampleConfiguration.GetGraphConfiguration();
5751

58-
EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.F1);
52+
EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.F1);
5953

6054
// The ExampleCullOperator extends the default cull operator (least fit) to include an artificial predator.
6155
CullOperator cullOperator = new ExampleCullOperator();
@@ -85,30 +79,8 @@ public static void main(String[] args) throws Exception {
8579
.build();
8680

8781
// Let's have a listener to print the population size after each evaluation.
88-
populationModel.addListener(new ExamplePopulationListener());
89-
90-
IOptimizationRunner runner = GeneticSearchExampleConfiguration.BuildRunner(candidateGenerator, scoreFunction);
91-
92-
//Start the hyperparameter optimization
93-
runner.execute();
94-
95-
//Print out some basic stats regarding the optimization procedure
96-
String s = "Best score: " + runner.bestScore() + "\n" +
97-
"Index of model with best score: " + runner.bestScoreCandidateIndex() + "\n" +
98-
"Number of configurations evaluated: " + runner.numCandidatesCompleted() + "\n";
99-
System.out.println(s);
100-
101-
102-
//Get all results, and print out details of the best result:
103-
int indexOfBestResult = runner.bestScoreCandidateIndex();
104-
List<ResultReference> allResults = runner.getResults();
105-
106-
OptimizationResult bestResult = allResults.get(indexOfBestResult).getResult();
107-
ComputationGraph bestModel = (ComputationGraph) bestResult.getResultReference().getResultModel();
108-
109-
System.out.println("\n\nConfiguration of best model:\n");
110-
System.out.println(bestModel.getConfiguration().toJson());
111-
82+
populationModel.addListener(new BaseGeneticHyperparameterOptimizationExample.ExamplePopulationListener());
83+
run(populationModel, candidateGenerator, scoreFunction);
11284
}
11385

11486
// This is an example of a custom behavior for the genetic algorithm. We force one of the parent to be one of the
@@ -158,17 +130,4 @@ public void cullPopulation() {
158130
System.out.println(String.format("Randomly removed %1$s candidate(s).", preyCount));
159131
}
160132
}
161-
162-
public static class ExamplePopulationListener implements PopulationListener {
163-
164-
@Override
165-
public void onChanged(List<Chromosome> population) {
166-
double best = population.get(0).getFitness();
167-
double average = population.stream()
168-
.mapToDouble(Chromosome::getFitness)
169-
.average()
170-
.getAsDouble();
171-
System.out.println(String.format("\nPopulation size is %1$s, best score is %2$s, average score is %3$s", population.size(), best, average));
172-
}
173-
}
174133
}

dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/GeneticSearchExampleConfiguration.java

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -49,22 +49,20 @@
4949
import java.io.File;
5050
import java.util.Properties;
5151

52-
public class GeneticSearchExampleConfiguration {
52+
class GeneticSearchExampleConfiguration {
5353

54-
public static ComputationGraphSpace GetGraphConfiguration() {
54+
static ComputationGraphSpace GetGraphConfiguration() {
5555
int inputSize = 784;
5656
int outputSize = 47;
5757

5858
// First, we setup the hyperspace parameters. These are the values which will change, breed and mutate
5959
// while attempting to find the best candidate.
60-
DiscreteParameterSpace<Activation> activationSpace = new DiscreteParameterSpace(new Activation[] {
61-
Activation.ELU,
60+
DiscreteParameterSpace<Activation> activationSpace = new DiscreteParameterSpace<>(Activation.ELU,
6261
Activation.RELU,
6362
Activation.LEAKYRELU,
6463
Activation.TANH,
6564
Activation.SELU,
66-
Activation.HARDSIGMOID
67-
});
65+
Activation.HARDSIGMOID);
6866
IntegerParameterSpace[] layersParametersSpace = new IntegerParameterSpace[] {
6967
new IntegerParameterSpace(outputSize, inputSize),
7068
new IntegerParameterSpace(outputSize, inputSize),

0 commit comments

Comments (0)