Skip to content

Commit 54b9b9b

Browse files
committed
Renamed evaluate to loss
1 parent af2f7df commit 54b9b9b

File tree

11 files changed

+51
-39
lines changed

11 files changed

+51
-39
lines changed

src/main/java/net/echo/brain4j/model/Model.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ public void connect(WeightInit weightInit, boolean update) {
117117
* @param set dataset for testing
118118
* @return the error of the model
119119
*/
120-
public abstract double evaluate(DataSet<R> set);
120+
public abstract double loss(DataSet<R> set);
121121

122122
/**
123123
* Trains the model for one epoch.
@@ -126,6 +126,18 @@ public void connect(WeightInit weightInit, boolean update) {
126126
*/
127127
public abstract void fit(DataSet<R> dataSet);
128128

129+
/**
130+
* Trains the model for the given number of epoches.
131+
*
132+
* @param dataSet dataset for training
133+
* @param epoches number of epoches
134+
*/
135+
public void fit(DataSet<R> dataSet, int epoches) {
136+
for (int i = 0; i < epoches; i++) {
137+
fit(dataSet);
138+
}
139+
}
140+
129141
/**
130142
* Predicts the output for the given input.
131143
*

src/main/java/net/echo/brain4j/model/impl/Sequential.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import net.echo.brain4j.loss.LossFunctions;
1313
import net.echo.brain4j.model.Model;
1414
import net.echo.brain4j.model.initialization.WeightInit;
15-
import net.echo.brain4j.structure.Neuron;
1615
import net.echo.brain4j.structure.Synapse;
1716
import net.echo.brain4j.structure.cache.StatesCache;
1817
import net.echo.brain4j.training.BackPropagation;
@@ -98,7 +97,7 @@ public Sequential compile(WeightInit weightInit, LossFunctions function, Optimiz
9897
}
9998

10099
@Override
101-
public double evaluate(DataSet<DataRow> set) {
100+
public double loss(DataSet<DataRow> set) {
102101
reloadMatrices();
103102

104103
AtomicReference<Double> totalError = new AtomicReference<>(0.0);

src/main/java/net/echo/brain4j/model/impl/Transformer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public void connect(WeightInit weightInit, boolean update) {
4242
}
4343

4444
@Override
45-
public double evaluate(DataSet<Object> set) {
45+
public double loss(DataSet<Object> set) {
4646
return 0;
4747
}
4848

src/main/java/net/echo/brain4j/training/techniques/SmartTrainer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public <R> void start(Model<R, ?, ?> model, DataSet<R> dataSet, double lossThres
4848

4949
if (epoches % evaluateEvery == 0) {
5050
long start = System.nanoTime();
51-
this.loss = model.evaluate(dataSet);
51+
this.loss = model.loss(dataSet);
5252
long took = System.nanoTime() - start;
5353

5454
this.listeners.forEach(listener -> listener.onEvaluated(dataSet, epoches, loss, took));
@@ -91,7 +91,7 @@ public <R> void startFor(Model<R, ?, ?> model, DataSet<R> dataSet, int epochesAm
9191

9292
if (epoches % evaluateEvery == 0) {
9393
long start = System.nanoTime();
94-
this.loss = model.evaluate(dataSet);
94+
this.loss = model.loss(dataSet);
9595
long took = System.nanoTime() - start;
9696

9797
this.listeners.forEach(listener -> listener.onEvaluated(dataSet, epoches, loss, took));

src/main/java/net/echo/brain4j/utils/MLUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public class MLUtils {
1717
* @param <T> the type of the enum
1818
* @return the best matching enum constant
1919
*/
20-
public static <T extends Enum<T>> T findBestMatch(Vector outputs, Class<T> clazz) {
20+
public static <T extends Enum<T>> T parse(Vector outputs, Class<T> clazz) {
2121
return clazz.getEnumConstants()[indexOfMaxValue(outputs)];
2222
}
2323

src/main/java/net/echo/brain4j/utils/Vector.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import java.util.Arrays;
44
import java.util.Iterator;
5+
import java.util.List;
56
import java.util.Random;
67
import java.util.function.Supplier;
78

@@ -68,6 +69,16 @@ public static Vector zeros(int size) {
6869
return new Vector(size).fill(0.0);
6970
}
7071

72+
public static Vector parse(List<String> pixels) {
73+
Vector vector = new Vector(pixels.size());
74+
75+
for (int i = 0; i < pixels.size(); i++) {
76+
vector.set(i, Double.parseDouble(pixels.get(i)));
77+
}
78+
79+
return vector;
80+
}
81+
7182
public void set(int index, double value) {
7283
data[index] = (float) value;
7384
}

src/test/java/ApproxExample.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ public void trainAndVisualize() {
134134
updateGraph(x, trueOutput, predictedOutput);
135135
}
136136

137-
double error = model.evaluate(dataSet);
137+
double error = model.loss(dataSet);
138138
errorLabel.setText("Error: " + String.format("%.4f", error) + " Took: " + String.format("%.2f", took) + " ms Epoch: " + i);
139139
System.out.print("\rEpoch #" + j + " Error: " + error + " Took: " + took + " ms");
140140
}

src/test/java/RNNExample.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public void start() {
2222
model.fit(dataSet);
2323

2424
if (i % 1000 == 0) {
25-
double loss = model.evaluate(dataSet);
25+
double loss = model.loss(dataSet);
2626

2727
System.out.println(i + ". Loss: " + loss);
2828
}

src/test/java/XorExample.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ private void start() {
3131

3232
double took = (System.nanoTime() - start) / 1e6;
3333

34-
System.out.println("Loss: " + model.evaluate(dataSet));
34+
System.out.println("Loss: " + model.loss(dataSet));
3535
System.out.println("Took: " + took + " ms");
3636

3737
for (DataRow row : dataSet) {

src/test/java/conv/ConvExample.java

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,16 @@
77
import net.echo.brain4j.layer.impl.convolution.InputLayer;
88
import net.echo.brain4j.loss.LossFunctions;
99
import net.echo.brain4j.model.impl.Sequential;
10-
import net.echo.brain4j.model.initialization.WeightInit;
1110
import net.echo.brain4j.training.data.DataRow;
1211
import net.echo.brain4j.training.optimizers.impl.Adam;
13-
import net.echo.brain4j.training.techniques.SmartTrainer;
14-
import net.echo.brain4j.training.techniques.TrainListener;
15-
import net.echo.brain4j.training.updater.impl.StochasticUpdater;
1612
import net.echo.brain4j.utils.DataSet;
1713
import net.echo.brain4j.utils.MLUtils;
1814
import net.echo.brain4j.utils.Vector;
19-
import org.apache.commons.io.FileUtils;
15+
import org.apache.commons.csv.CSVFormat;
16+
import org.apache.commons.csv.CSVParser;
2017

21-
import java.io.File;
18+
import java.io.FileReader;
2219
import java.io.IOException;
23-
import java.util.Arrays;
2420
import java.util.List;
2521

2622
public class ConvExample {
@@ -36,15 +32,7 @@ private void start() throws IOException {
3632

3733
System.out.println(model.getStats());
3834

39-
SmartTrainer trainer = new SmartTrainer(1, 1);
40-
trainer.addListener(new TrainListener<DataRow>() {
41-
@Override
42-
public void onEvaluated(DataSet<DataRow> dataSet, int epoch, double loss, long took) {
43-
System.out.println("Epoch #" + epoch + " Loss: " + loss);
44-
}
45-
});
46-
trainer.startFor(model, dataSet, 100, 0.01);
47-
35+
model.fit(dataSet, 10);
4836
model.save("mnist-conv.json");
4937

5038
int incorrect = 0;
@@ -91,20 +79,22 @@ private Sequential getModel() {
9179

9280
private DataSet<DataRow> getDataSet() throws IOException {
9381
DataSet<DataRow> dataSet = new DataSet<>();
94-
List<String> lines = FileUtils.readLines(new File("dataset.csv"), "UTF-8");
9582

96-
for (int j = 0; j < 150 * 2; j++) {
97-
String line = lines.get(j);
98-
String[] parts = line.split(",");
99-
double[] inputs = Arrays.stream(parts, 1, parts.length).mapToDouble(x -> Double.parseDouble(x) / 255).toArray();
83+
FileReader reader = new FileReader("dataset.csv");
84+
CSVParser parser = new CSVParser(reader, CSVFormat.EXCEL);
10085

101-
Vector output = new Vector(10);
86+
parser.forEach(record -> {
87+
List<String> columns = record.toList();
10288

103-
int value = Integer.parseInt(parts[0]);
104-
output.set(value, 1);
89+
String label = columns.getFirst();
90+
List<String> pixels = columns.subList(1, columns.size());
10591

106-
dataSet.getData().add(new DataRow(Vector.of(inputs), output));
107-
}
92+
Vector output = new Vector(10);
93+
output.set(Integer.parseInt(label), 1);
94+
95+
Vector input = Vector.parse(pixels).divide(255);
96+
dataSet.add(new DataRow(input, output));
97+
});
10898

10999
return dataSet;
110100
}

0 commit comments

Comments
 (0)