Skip to content

Commit 0222a36

Browse files
committed
wip until feedforward/classification.
Signed-off-by: Robert Altena <[email protected]>
1 parent 306dbe9 commit 0222a36

File tree

5 files changed

+20
-19
lines changed

5 files changed

+20
-19
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/classification/MLPClassifierLinear.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -49,6 +49,7 @@
4949
* @author Alex Black (added plots)
5050
*
5151
*/
52+
@SuppressWarnings("DuplicatedCode")
5253
public class MLPClassifierLinear {
5354

5455
public static String dataLocalPath;

dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/classification/MLPClassifierMoon.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -49,6 +49,7 @@
4949
* @author Alex Black (added plots)
5050
*
5151
*/
52+
@SuppressWarnings("DuplicatedCode")
5253
public class MLPClassifierMoon {
5354

5455
public static String dataLocalPath;
@@ -75,9 +76,6 @@ public static void main(String[] args) throws Exception {
7576
rrTest.initialize(new FileSplit(new File(dataLocalPath,"moon_data_eval.csv")));
7677
DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest,batchSize,0,2);
7778

78-
DataSet ds1 = trainIter.next();
79-
DataSet ds2 = testIter.next();
80-
8179
//log.info("Build model....");
8280
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
8381
.seed(seed)

dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/classification/MLPClassifierSaturn.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -20,7 +20,6 @@
2020
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
2121
import org.datavec.api.split.FileSplit;
2222
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
23-
import org.deeplearning4j.eval.Evaluation;
2423
import org.deeplearning4j.examples.download.DownloaderUtility;
2524
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
2625
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -29,6 +28,7 @@
2928
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
3029
import org.deeplearning4j.nn.weights.WeightInit;
3130
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
31+
import org.nd4j.evaluation.classification.Evaluation;
3232
import org.nd4j.linalg.activations.Activation;
3333
import org.nd4j.linalg.api.ndarray.INDArray;
3434
import org.nd4j.linalg.dataset.DataSet;
@@ -49,6 +49,7 @@
4949
* @author Alex Black (added plots)
5050
*
5151
*/
52+
@SuppressWarnings("DuplicatedCode")
5253
public class MLPClassifierSaturn {
5354

5455
public static String dataLocalPath;

dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/classification/PlotUtil.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -42,7 +42,7 @@
4242
/**Simple plotting methods for the MLPClassifier examples
4343
* @author Alex Black
4444
*/
45-
public class PlotUtil {
45+
class PlotUtil {
4646

4747
/**Plot the training data. Assume 2d input, classification output
4848
* @param features Training data features
@@ -51,7 +51,7 @@ public class PlotUtil {
5151
* @param backgroundOut results of network evaluation at points in x,y points in space
5252
* @param nDivisions Number of points (per axis, for the backgroundIn/backgroundOut arrays)
5353
*/
54-
public static void plotTrainingData(INDArray features, INDArray labels, INDArray backgroundIn, INDArray backgroundOut, int nDivisions){
54+
static void plotTrainingData(INDArray features, INDArray labels, INDArray backgroundIn, INDArray backgroundOut, int nDivisions){
5555
double[] mins = backgroundIn.min(0).data().asDouble();
5656
double[] maxs = backgroundIn.max(0).data().asDouble();
5757

@@ -75,7 +75,7 @@ public static void plotTrainingData(INDArray features, INDArray labels, INDArray
7575
* @param backgroundOut results of network evaluation at points in x,y points in space
7676
* @param nDivisions Number of points (per axis, for the backgroundIn/backgroundOut arrays)
7777
*/
78-
public static void plotTestData(INDArray features, INDArray labels, INDArray predicted, INDArray backgroundIn, INDArray backgroundOut, int nDivisions){
78+
static void plotTestData(INDArray features, INDArray labels, INDArray predicted, INDArray backgroundIn, INDArray backgroundOut, int nDivisions){
7979

8080
double[] mins = backgroundIn.min(0).data().asDouble();
8181
double[] maxs = backgroundIn.max(0).data().asDouble();
@@ -121,7 +121,7 @@ private static XYDataset createDataSetTrain(INDArray features, INDArray labels )
121121
int nClasses = 2; // Binary classification using one output call end sigmoid.
122122

123123
XYSeries[] series = new XYSeries[nClasses];
124-
for( int i=0; i<series.length; i++) series[i] = new XYSeries("Class " + String.valueOf(i));
124+
for( int i=0; i<series.length; i++) series[i] = new XYSeries("Class " + i);
125125
INDArray argMax = Nd4j.getExecutioner().exec(new IMax(labels, 1));
126126
for( int i=0; i<nRows; i++ ){
127127
int classIdx = (int)argMax.getDouble(i);

dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/classification/detectgender/PredictGenderTrain.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -16,14 +16,13 @@
1616

1717
package org.deeplearning4j.examples.feedforward.classification.detectgender;
1818

19-
/**
19+
/*
2020
* Created by KIT Solutions (www.kitsol.com) on 9/28/2016.
2121
*/
2222

2323
import org.datavec.api.split.FileSplit;
2424
import org.deeplearning4j.api.storage.StatsStorage;
2525
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
26-
import org.deeplearning4j.eval.Evaluation;
2726
import org.deeplearning4j.examples.download.DownloaderUtility;
2827
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
2928
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -34,6 +33,7 @@
3433
import org.deeplearning4j.ui.api.UIServer;
3534
import org.deeplearning4j.ui.stats.StatsListener;
3635
import org.deeplearning4j.ui.storage.FileStatsStorage;
36+
import org.nd4j.evaluation.classification.Evaluation;
3737
import org.nd4j.linalg.activations.Activation;
3838
import org.nd4j.linalg.api.ndarray.INDArray;
3939
import org.nd4j.linalg.dataset.DataSet;
@@ -44,13 +44,14 @@
4444
import java.io.File;
4545
import java.util.ArrayList;
4646

47+
@SuppressWarnings("DuplicatedCode")
4748
public class PredictGenderTrain
4849
{
4950
public String filePath;
5051
public static String dataLocalPath;
5152

5253

53-
public static void main(String args[]) throws Exception {
54+
public static void main(String[] args) throws Exception {
5455

5556
dataLocalPath = DownloaderUtility.PREDICTGENDERDATA.Download();
5657
PredictGenderTrain dg = new PredictGenderTrain();
@@ -68,9 +69,9 @@ public void train()
6869
double learningRate = 0.005;// was .01 but often got errors: "o.d.optimize.solvers.BaseOptimizer - Hit termination condition on iteration 0"
6970
int batchSize = 100;
7071
int nEpochs = 10;
71-
int numInputs = 0;
72-
int numOutputs = 0;
73-
int numHiddenNodes = 0;
72+
int numInputs;
73+
int numOutputs;
74+
int numHiddenNodes;
7475

7576
try(GenderRecordReader rr = new GenderRecordReader(new ArrayList<String>() {{add("M");add("F");}}))
7677
{

0 commit comments

Comments
 (0)