
Commit 7823f81

Author: Ryan Nett (committed)
Custom listener and new training example (can't use validation b/c beta5 bugs)
Signed-off-by: Ryan Nett <[email protected]>
1 parent 890a132 commit 7823f81

File tree: 3 files changed, 229 additions and 0 deletions

dl4j-examples/src/main/java/org/deeplearning4j/examples/samediff/training/SameDiffCustomListenerExample.java

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
package org.deeplearning4j.examples.samediff.training;

import static org.deeplearning4j.examples.samediff.training.SameDiffMNISTTrainingExample.makeMNISTNet;

import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerVariables;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.weightinit.impl.XavierInitScheme;

/**
 * This example shows how to use a custom listener, and is based on the {@link SameDiffMNISTTrainingExample}.
 */
public class SameDiffCustomListenerExample {

    /**
     * A basic custom listener that records the values of z and out, for comparison later
     */
    public static class CustomListener extends BaseListener {

        public INDArray z;
        public INDArray out;

        // Specify that this listener is active during inference operations
        @Override
        public boolean isActive(Operation operation) {
            return operation == Operation.INFERENCE;
        }

        // Specify that this listener requires the activations of "z" and "out"
        @Override
        public ListenerVariables requiredVariables(SameDiff sd) {
            return new ListenerVariables.Builder().inferenceVariables("z", "out").build();
        }

        // Called when the activation of a variable becomes available
        @Override
        public void activationAvailable(SameDiff sd, At at,
                MultiDataSet batch, SameDiffOp op,
                String varName, INDArray activation) {

            // if the variable is z or out, store its activation
            if (varName.equals("out")) {
                out = activation.detach().dup();
            } else if (varName.equals("z")) {
                z = activation.detach().dup();
            }
        }
    }


    public static void main(String[] args) throws Exception {
        SameDiff sd = makeMNISTNet();

        //Create and set the training configuration
        double learningRate = 1e-3;
        TrainingConfig config = new TrainingConfig.Builder()
            .l2(1e-4)                           //L2 regularization
            .updater(new Adam(learningRate))    //Adam optimizer with specified learning rate
            .dataSetFeatureMapping("input")     //DataSet features array should be associated with variable "input"
            .dataSetLabelMapping("label")       //DataSet label array should be associated with variable "label"
            .build();

        sd.setTrainingConfig(config);

        int batchSize = 32;
        DataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 12345);

        //Perform training
        History hist = sd.fit()
            .train(trainData, 4)
            .exec();
        List<Double> acc = hist.trainingEval(Metric.ACCURACY);

        System.out.println("Accuracy: " + acc);

        CustomListener listener = new CustomListener();

        sd.output()
            .data(new MnistDataSetIterator(10, 10, false, false, true, 12345))
            .output("out")
            .listeners(listener)
            .exec();

        System.out.println("Z: " + listener.z);
        System.out.println("Out (softmax(z)): " + listener.out);
    }
}
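A quick way to check what the listener captured (not part of the commit): since "out" is the softmax of "z" in this network, the two stored arrays should match an explicitly computed softmax. This is a minimal sketch; the helper class and its name are illustrative, and it assumes Transforms.softmax from org.nd4j.linalg.ops.transforms applies softmax along the last dimension of a 2D array.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;

// Illustrative helper (hypothetical name): compare the activations captured by CustomListener.
public class ListenerSanityCheck {

    public static void check(INDArray z, INDArray out) {
        // Assumption: Transforms.softmax applies softmax along the last dimension of a 2D array
        INDArray softmaxOfZ = Transforms.softmax(z.dup());
        // Element-wise comparison within a small tolerance
        System.out.println("out == softmax(z)? " + out.equalsWithEps(softmaxOfZ, 1e-5));
    }
}

In the example above this could be called as ListenerSanityCheck.check(listener.z, listener.out) after the output pass, instead of printing the raw arrays.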
dl4j-examples/src/main/java/org/deeplearning4j/examples/samediff/training/SameDiffMNISTTrainingExample.java

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
package org.deeplearning4j.examples.samediff.training;

import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.autodiff.listeners.ListenerEvaluations;
import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.weightinit.impl.OneInitScheme;
import org.nd4j.weightinit.impl.XavierInitScheme;

/**
 * This example shows the creation and training of a MNIST CNN network.
 */
public class SameDiffMNISTTrainingExample {

    public static SameDiff makeMNISTNet(){
        SameDiff sd = SameDiff.create();

        //Properties for MNIST dataset:
        int nIn = 28*28;
        int nOut = 10;

        //Create input and label variables
        SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, nIn);      //Shape: [?, 784] - i.e., minibatch x 784 for MNIST
        SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, nOut);  //Shape: [?, 10] - i.e., minibatch x 10 for MNIST

        SDVariable reshaped = in.reshape(-1, 1, 28, 28);

        Pooling2DConfig poolConfig = Pooling2DConfig.builder().kH(2).kW(2).sH(2).sW(2).build();

        Conv2DConfig convConfig = Conv2DConfig.builder().kH(3).kW(3).build();

        // layer 1: Conv2D with a 3x3 kernel and 4 output channels
        SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 28 * 28, 26 * 26 * 4), DataType.FLOAT, 3, 3, 1, 4);
        SDVariable b0 = sd.zero("b0", 4);

        SDVariable conv1 = sd.cnn().conv2d(reshaped, w0, b0, convConfig);

        // layer 2: MaxPooling2D with a 2x2 kernel and stride, and ReLU activation
        SDVariable pool1 = sd.cnn().maxPooling2d(conv1, poolConfig);

        SDVariable relu1 = sd.nn().relu(pool1, 0);

        // layer 3: Conv2D with a 3x3 kernel and 8 output channels
        SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 13 * 13 * 4, 11 * 11 * 8), DataType.FLOAT, 3, 3, 4, 8);
        SDVariable b1 = sd.zero("b1", 8);

        SDVariable conv2 = sd.cnn().conv2d(relu1, w1, b1, convConfig);

        // layer 4: MaxPooling2D with a 2x2 kernel and stride, and ReLU activation
        SDVariable pool2 = sd.cnn().maxPooling2d(conv2, poolConfig);

        SDVariable relu2 = sd.nn().relu(pool2, 0);

        SDVariable flat = relu2.reshape(-1, 5 * 5 * 8);

        // layer 5: Output layer on flattened input
        SDVariable wOut = sd.var("wOut", new XavierInitScheme('c', 5 * 5 * 8, 10), DataType.FLOAT, 5 * 5 * 8, 10);
        SDVariable bOut = sd.zero("bOut", 10);

        SDVariable z = sd.nn().linear("z", flat, wOut, bOut);

        // softmax crossentropy loss function
        SDVariable loss = sd.loss().softmaxCrossEntropy("loss", label, z);

        SDVariable out = sd.nn().softmax("out", z, 1);

        sd.setLossVariables(loss);

        return sd;
    }

    public static void main(String[] args) throws Exception {
        SameDiff sd = makeMNISTNet();

        //Create and set the training configuration

        Evaluation evaluation = new Evaluation();

        double learningRate = 1e-3;
        TrainingConfig config = new TrainingConfig.Builder()
            .l2(1e-4)                               //L2 regularization
            .updater(new Adam(learningRate))        //Adam optimizer with specified learning rate
            .dataSetFeatureMapping("input")         //DataSet features array should be associated with variable "input"
            .dataSetLabelMapping("label")           //DataSet label array should be associated with variable "label"
            .trainEvaluation("out", 0, evaluation)  // add a training evaluation
            .build();

        // You can add validation evaluations as well, but they have some issues in beta5 and most likely won't work.
        // If you want to use them, use the SNAPSHOT build.

        sd.setTrainingConfig(config);

        // Adding a listener to the SameDiff instance is necessary because of a beta5 bug, and is not necessary in snapshots
        sd.addListeners(new ScoreListener(20));

        int batchSize = 32;
        DataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 12345);

        //Perform training for 4 epochs
        int numEpochs = 4;
        History hist = sd.fit()
            .train(trainData, numEpochs)
            .exec();
        List<Double> acc = hist.trainingEval(Metric.ACCURACY);

        System.out.println("Accuracy: " + acc);
    }
}
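The comment in main() above notes that validation evaluations are unreliable in beta5. One workaround is to evaluate the trained network out-of-band on the MNIST test split after fit() returns. This is a sketch, not part of the commit: the class name is illustrative, and it assumes SameDiff exposes an evaluate(DataSetIterator, String outputVariable, IEvaluation...) overload in the version you are using.

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

// Illustrative helper (hypothetical name): evaluate a trained MNIST SameDiff net on the test split.
public class SameDiffTestSetEvaluation {

    public static void evaluateOnTestSet(SameDiff sd) throws Exception {
        // 'false' selects the MNIST test set rather than the training set
        DataSetIterator testData = new MnistDataSetIterator(32, false, 12345);
        Evaluation testEval = new Evaluation();
        // Assumed overload: evaluate(DataSetIterator, String outputVariable, IEvaluation...)
        sd.evaluate(testData, "out", testEval);
        System.out.println("Test accuracy: " + testEval.accuracy());
    }
}

Called as SameDiffTestSetEvaluation.evaluateOnTestSet(sd) at the end of main(), this gives a held-out accuracy without relying on the in-training validation hooks.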

dl4j-examples/src/main/java/org/deeplearning4j/examples/samediff/training/SameDiffTrainingExample.java

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,8 @@ public static void main(String[] args) throws Exception {
         SDVariable diff = sd.f().squaredDifference(softmax, label);
         SDVariable lossMse = diff.mean();
 
+        sd.setLossVariables(lossMse);
+
         //Create and set the training configuration
         double learningRate = 1e-3;
         TrainingConfig config = new TrainingConfig.Builder()
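The added call marks lossMse as the graph's loss variable so that fit() knows what to minimize. For context, a minimal self-contained sketch of the pattern, with assumed shapes and variable names that mirror the surrounding examples rather than the actual SameDiffTrainingExample code:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.weightinit.impl.XavierInitScheme;

// Illustrative only: build a tiny graph, define an MSE-style loss, and mark it as the loss variable.
public class ExplicitLossSketch {

    public static SameDiff build() {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
        SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 4);
        SDVariable w = sd.var("w", new XavierInitScheme('c', 4, 4), DataType.FLOAT, 4, 4);
        SDVariable softmax = sd.nn().softmax("out", in.mmul(w), 1);
        // Mean squared difference between prediction and label, mirroring the diff above
        SDVariable lossMse = sd.f().squaredDifference(softmax, label).mean();
        // Without this call, training may not know which variable to minimize
        sd.setLossVariables(lossMse);
        return sd;
    }
}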

0 commit comments
