Skip to content

Commit 663f6e2

Browse files
committed
samediff/training.
Signed-off-by: Robert Altena <[email protected]>
1 parent e4be2ad commit 663f6e2

File tree

3 files changed

+13
-25
lines changed

3 files changed

+13
-25
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/samediff/training/SameDiffCustomListenerExample.java

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,21 @@
11
package org.deeplearning4j.examples.samediff.training;
22

3-
import static org.deeplearning4j.examples.samediff.training.SameDiffMNISTTrainingExample.makeMNISTNet;
4-
5-
import java.util.Arrays;
6-
import java.util.List;
73
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
84
import org.deeplearning4j.examples.samediff.tfimport.SameDiffTransferLearningExample;
9-
import org.nd4j.autodiff.listeners.At;
10-
import org.nd4j.autodiff.listeners.BaseEvaluationListener;
11-
import org.nd4j.autodiff.listeners.BaseListener;
12-
import org.nd4j.autodiff.listeners.Listener;
13-
import org.nd4j.autodiff.listeners.ListenerVariables;
14-
import org.nd4j.autodiff.listeners.Operation;
15-
import org.nd4j.autodiff.listeners.impl.ScoreListener;
5+
import org.nd4j.autodiff.listeners.*;
166
import org.nd4j.autodiff.listeners.records.History;
17-
import org.nd4j.autodiff.samediff.SDVariable;
187
import org.nd4j.autodiff.samediff.SameDiff;
198
import org.nd4j.autodiff.samediff.TrainingConfig;
209
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
21-
import org.nd4j.evaluation.classification.Evaluation;
2210
import org.nd4j.evaluation.classification.Evaluation.Metric;
23-
import org.nd4j.linalg.api.buffer.DataType;
2411
import org.nd4j.linalg.api.ndarray.INDArray;
25-
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
26-
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
2712
import org.nd4j.linalg.dataset.api.MultiDataSet;
2813
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
2914
import org.nd4j.linalg.learning.config.Adam;
30-
import org.nd4j.weightinit.impl.XavierInitScheme;
15+
16+
import java.util.List;
17+
18+
import static org.deeplearning4j.examples.samediff.training.SameDiffMNISTTrainingExample.makeMNISTNet;
3119

3220
/**
3321
* This example shows how to use a custom listener, and is based on the {@link SameDiffMNISTTrainingExample}.<br><br>
@@ -40,6 +28,7 @@
4028
* If you want to use evaluations in your listener, look at {@link BaseEvaluationListener}.
4129
*
4230
*/
31+
@SuppressWarnings("DuplicatedCode")
4332
public class SameDiffCustomListenerExample {
4433

4534
public static void main(String[] args) throws Exception {
@@ -84,7 +73,7 @@ public static void main(String[] args) throws Exception {
8473
*/
8574
public static class CustomListener extends BaseListener {
8675

87-
public INDArray z;
76+
INDArray z;
8877
public INDArray out;
8978

9079
// Specify that this listener is active during inference operations

dl4j-examples/src/main/java/org/deeplearning4j/examples/samediff/training/SameDiffMNISTTrainingExample.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
package org.deeplearning4j.examples.samediff.training;
22

3-
import java.util.Arrays;
4-
import java.util.List;
53
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
6-
import org.nd4j.autodiff.listeners.ListenerEvaluations;
7-
import org.nd4j.autodiff.listeners.impl.HistoryListener;
84
import org.nd4j.autodiff.listeners.impl.ScoreListener;
95
import org.nd4j.autodiff.listeners.records.History;
106
import org.nd4j.autodiff.samediff.SDVariable;
@@ -17,15 +13,16 @@
1713
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
1814
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
1915
import org.nd4j.linalg.learning.config.Adam;
20-
import org.nd4j.weightinit.impl.OneInitScheme;
2116
import org.nd4j.weightinit.impl.XavierInitScheme;
2217

18+
import java.util.List;
19+
2320
/**
2421
* This example shows the creation and training of a MNIST CNN network.
2522
*/
2623
public class SameDiffMNISTTrainingExample {
2724

28-
public static SameDiff makeMNISTNet(){
25+
static SameDiff makeMNISTNet(){
2926
SameDiff sd = SameDiff.create();
3027

3128
//Properties for MNIST dataset:
@@ -75,6 +72,7 @@ public static SameDiff makeMNISTNet(){
7572
// softmax crossentropy loss function
7673
SDVariable loss = sd.loss().softmaxCrossEntropy("loss", label, z);
7774

75+
//noinspection unused
7876
SDVariable out = sd.nn().softmax("out", z, 1);
7977

8078
sd.setLossVariables(loss);

dl4j-examples/src/main/java/org/deeplearning4j/examples/samediff/training/SameDiffTrainingExample.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -39,6 +39,7 @@
3939
*
4040
* @author Alex Black
4141
*/
42+
@SuppressWarnings("DuplicatedCode")
4243
public class SameDiffTrainingExample {
4344

4445
public static void main(String[] args) throws Exception {

0 commit comments

Comments
 (0)