Skip to content

Commit 6a3d441

Browse files
committed
Added an evaluation method
1 parent 54b9b9b commit 6a3d441

File tree

5 files changed

+115
-25
lines changed

5 files changed

+115
-25
lines changed

src/main/java/net/echo/brain4j/model/Model.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import net.echo.brain4j.structure.Synapse;
1919
import net.echo.brain4j.structure.cache.Parameters;
2020
import net.echo.brain4j.structure.cache.StatesCache;
21+
import net.echo.brain4j.training.evaluation.EvaluationResult;
2122
import net.echo.brain4j.training.optimizers.Optimizer;
2223
import net.echo.brain4j.training.optimizers.impl.Adam;
2324
import net.echo.brain4j.training.optimizers.impl.AdamW;
@@ -114,10 +115,10 @@ public void connect(WeightInit weightInit, boolean update) {
114115
/**
115116
* Evaluates the model on the given dataset.
116117
*
117-
* @param set dataset for testing
118+
* @param dataSet dataset for testing
118119
* @return the error of the model
119120
*/
120-
public abstract double loss(DataSet<R> set);
121+
public abstract double loss(DataSet<R> dataSet);
121122

122123
/**
123124
* Trains the model for one epoch.
@@ -146,6 +147,13 @@ public void fit(DataSet<R> dataSet, int epoches) {
146147
*/
147148
public abstract O predict(I input);
148149

150+
/**
151+
* Evaluates the model performance on the given dataset.
152+
* @param dataSet dataset to evaluate
153+
* @return an evaluation result
154+
*/
155+
public abstract EvaluationResult evaluate(DataSet<R> dataSet);
156+
149157
/**
150158
* Predicts output for given input.
151159
*

src/main/java/net/echo/brain4j/model/impl/Sequential.java

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@
1616
import net.echo.brain4j.structure.cache.StatesCache;
1717
import net.echo.brain4j.training.BackPropagation;
1818
import net.echo.brain4j.training.data.DataRow;
19+
import net.echo.brain4j.training.evaluation.EvaluationResult;
1920
import net.echo.brain4j.training.optimizers.Optimizer;
2021
import net.echo.brain4j.training.updater.Updater;
2122
import net.echo.brain4j.training.updater.impl.StochasticUpdater;
2223
import net.echo.brain4j.utils.DataSet;
24+
import net.echo.brain4j.utils.MLUtils;
2325
import net.echo.brain4j.utils.Vector;
2426

2527
import java.util.ArrayList;
28+
import java.util.HashMap;
2629
import java.util.List;
30+
import java.util.Map;
31+
import java.util.concurrent.ConcurrentHashMap;
2732
import java.util.concurrent.atomic.AtomicReference;
2833

2934
import static net.echo.brain4j.utils.MLUtils.waitAll;
@@ -97,18 +102,61 @@ public Sequential compile(WeightInit weightInit, LossFunctions function, Optimiz
97102
}
98103

99104
@Override
100-
public double loss(DataSet<DataRow> set) {
105+
public EvaluationResult evaluate(DataSet<DataRow> dataSet) {
106+
int classes = layers.getLast().getNeurons().size();
107+
108+
Map<Integer, Integer> correctlyClassified = new ConcurrentHashMap<>();
109+
Map<Integer, Integer> incorrectlyClassified = new ConcurrentHashMap<>();
110+
111+
for (int i = 0; i < classes; i++) {
112+
correctlyClassified.put(i, 0);
113+
incorrectlyClassified.put(i, 0);
114+
}
115+
116+
List<Thread> threads = new ArrayList<>();
117+
118+
if (!dataSet.isPartitioned()) {
119+
dataSet.partition(Math.min(Runtime.getRuntime().availableProcessors(), dataSet.getData().size()));
120+
}
121+
122+
for (List<DataRow> partition : dataSet.getPartitions()) {
123+
threads.add(makeEvaluation(partition, correctlyClassified, incorrectlyClassified));
124+
}
125+
126+
waitAll(threads);
127+
return new EvaluationResult(classes, correctlyClassified, incorrectlyClassified);
128+
}
129+
130+
private Thread makeEvaluation(List<DataRow> partition, Map<Integer, Integer> correctlyClassified, Map<Integer, Integer> incorrectlyClassified) {
131+
return Thread.startVirtualThread(() -> {
132+
for (DataRow row : partition) {
133+
Vector prediction = predict(row.inputs());
134+
135+
int predIndex = MLUtils.indexOfMaxValue(prediction);
136+
int targetIndex = MLUtils.indexOfMaxValue(row.outputs());
137+
138+
if (predIndex == targetIndex) {
139+
correctlyClassified.compute(targetIndex, (k, v) -> v == null ? 1 : v + 1);
140+
} else {
141+
incorrectlyClassified.compute(targetIndex, (k, v) -> v == null ? 1 : v + 1);
142+
}
143+
}
144+
});
145+
}
146+
147+
@Override
148+
public double loss(DataSet<DataRow> dataSet) {
101149
reloadMatrices();
102150

103151
AtomicReference<Double> totalError = new AtomicReference<>(0.0);
104152
List<Thread> threads = new ArrayList<>();
105153

106-
for (List<DataRow> partition : set.getPartitions()) {
154+
for (List<DataRow> partition : dataSet.getPartitions()) {
107155
threads.add(predictPartition(partition, totalError));
108156
}
109157

110158
waitAll(threads);
111-
return totalError.get() / set.size();
159+
return totalError.get() / dataSet.size();
112160
}
113161

114162
@Override

src/main/java/net/echo/brain4j/model/impl/Transformer.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import net.echo.brain4j.model.Model;
66
import net.echo.brain4j.model.initialization.WeightInit;
77
import net.echo.brain4j.structure.cache.StatesCache;
8+
import net.echo.brain4j.training.evaluation.EvaluationResult;
89
import net.echo.brain4j.training.optimizers.Optimizer;
910
import net.echo.brain4j.training.updater.Updater;
1011
import net.echo.brain4j.training.updater.impl.StochasticUpdater;
@@ -36,13 +37,18 @@ public Transformer compile(WeightInit weightInit, LossFunctions function, Optimi
3637
return this;
3738
}
3839

40+
@Override
41+
public EvaluationResult evaluate(DataSet<Object> dataSet) {
42+
return null;
43+
}
44+
3945
@Override
4046
public void connect(WeightInit weightInit, boolean update) {
4147
super.connect(weightInit, update);
4248
}
4349

4450
@Override
45-
public double loss(DataSet<Object> set) {
51+
public double loss(DataSet<Object> dataSet) {
4652
return 0;
4753
}
4854

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package net.echo.brain4j.training.evaluation;
2+
3+
import java.util.Map;
4+
5+
public record EvaluationResult(int classes, Map<Integer, Integer> correctlyClassified, Map<Integer, Integer> incorrectlyClassified) {
6+
7+
public String confusionMatrix() {
8+
StringBuilder matrix = new StringBuilder();
9+
String header = "======================================\n";
10+
11+
String pattern = "%-10s %-10s %-10s\n";
12+
matrix.append(String.format(pattern, "Classes", "Correct", "Incorrect"));
13+
matrix.append(header);
14+
15+
int totalCorrect = 0;
16+
int totalIncorrect = 0;
17+
18+
for (int i = 0; i < correctlyClassified.size(); i++) {
19+
totalCorrect += correctlyClassified.get(i);
20+
totalIncorrect += incorrectlyClassified.get(i);
21+
}
22+
23+
for (int i = 0; i < classes; i++) {
24+
matrix.append(String.format(pattern, i, correctlyClassified.get(i), incorrectlyClassified.get(i)));
25+
}
26+
27+
matrix.append(header);
28+
29+
double accuracy = totalCorrect / (double) (totalCorrect + totalIncorrect);
30+
double precision = totalCorrect / (double) (totalCorrect + incorrectlyClassified.get(0));
31+
double recall = totalCorrect / (double) (totalCorrect + correctlyClassified.get(0));
32+
33+
String secondary = "%-20s %-10s\n";
34+
matrix.append(String.format(secondary, "Accuracy:", String.format("%.2f", accuracy * 100), ""));
35+
matrix.append(String.format(secondary, "Precision:", String.format("%.2f", precision * 100), ""));
36+
matrix.append(String.format(secondary, "Recall: ", String.format("%.2f", recall * 100), ""));
37+
matrix.append(header);
38+
39+
return matrix.toString();
40+
}
41+
}

src/test/java/conv/ConvExample.java

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import net.echo.brain4j.loss.LossFunctions;
99
import net.echo.brain4j.model.impl.Sequential;
1010
import net.echo.brain4j.training.data.DataRow;
11+
import net.echo.brain4j.training.evaluation.EvaluationResult;
1112
import net.echo.brain4j.training.optimizers.impl.Adam;
1213
import net.echo.brain4j.utils.DataSet;
1314
import net.echo.brain4j.utils.MLUtils;
@@ -17,6 +18,7 @@
1718

1819
import java.io.FileReader;
1920
import java.io.IOException;
21+
import java.util.HashMap;
2022
import java.util.List;
2123

2224
public class ConvExample {
@@ -32,27 +34,12 @@ private void start() throws IOException {
3234

3335
System.out.println(model.getStats());
3436

35-
model.fit(dataSet, 10);
36-
model.save("mnist-conv.json");
37+
// model.fit(dataSet, 10);
38+
// model.save("mnist-conv.json");
3739

38-
int incorrect = 0;
40+
EvaluationResult result = model.evaluate(dataSet);
3941

40-
for (DataRow row : dataSet) {
41-
Vector input = row.inputs();
42-
Vector output = row.outputs();
43-
Vector prediction = model.predict(input);
44-
45-
int expected = MLUtils.indexOfMaxValue(output) + 1;
46-
int predicted = MLUtils.indexOfMaxValue(prediction) + 1;
47-
48-
if (expected != predicted) {
49-
System.out.println("Expected: " + expected + ", Prediction: " + predicted);
50-
incorrect++;
51-
}
52-
}
53-
54-
int correct = dataSet.size() - incorrect;
55-
System.out.println("Correct: " + correct + ", Incorrect: " + incorrect);
42+
System.out.println(result.confusionMatrix());
5643
}
5744

5845
private Sequential getModel() {

0 commit comments

Comments
 (0)