Skip to content

Commit 90a2971

Browse files
committed
Fixed evaluate method for binary classification
1 parent bc2935d commit 90a2971

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

src/main/java/net/echo/brain4j/model/impl/Sequential.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,12 @@ public Sequential compile(WeightInit weightInit, LossFunctions function, Optimiz
103103

104104
@Override
105105
public EvaluationResult evaluate(DataSet<DataRow> dataSet) {
106-
int classes = layers.getLast().getNeurons().size();
106+
int classes = dataSet.getData().getFirst().outputs().size();
107+
108+
// Binary classification
109+
if (classes == 1) {
110+
classes = 2;
111+
}
107112

108113
Map<Integer, Vector> classifications = new ConcurrentHashMap<>();
109114

@@ -133,6 +138,11 @@ private Thread makeEvaluation(List<DataRow> partition, Map<Integer, Vector> clas
133138
int predIndex = MLUtils.indexOfMaxValue(prediction);
134139
int targetIndex = MLUtils.indexOfMaxValue(row.outputs());
135140

141+
if (row.outputs().size() == 1) {
142+
predIndex = prediction.get(0) > 0.5 ? 1 : 0;
143+
targetIndex = (int) row.outputs().get(0);
144+
}
145+
136146
Vector predictions = classifications.get(targetIndex);
137147
predictions.set(predIndex, predictions.get(predIndex) + 1);
138148
}

src/main/java/net/echo/brain4j/training/evaluation/EvaluationResult.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public String confusionMatrix() {
7474
}
7575

7676
matrix.append("\n ");
77-
matrix.append("-".repeat(classes * 6)).append("\n");
77+
matrix.append("-".repeat(5 + classes * 5)).append("\n");
7878

7979
for (int i = 0; i < classes; i++) {
8080
StringBuilder text = new StringBuilder();

src/test/java/XorExample.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,15 @@ private void start() {
2323
System.out.println(model.getStats());
2424

2525
long start = System.nanoTime();
26-
27-
// Fit the model for 1000 epoches
28-
for (int i = 0; i < 100; i++) {
29-
model.fit(dataSet);
30-
}
31-
26+
model.fit(dataSet, 1000);
3227
double took = (System.nanoTime() - start) / 1e6;
3328

3429
System.out.println("Loss: " + model.loss(dataSet));
3530
System.out.println("Took: " + took + " ms");
3631

37-
for (DataRow row : dataSet) {
38-
Vector prediction = model.predict(row.inputs());
32+
var result = model.evaluate(dataSet);
3933

40-
System.out.println("Expected: " + row.outputs() + " -> Predicted: " + prediction.toString("%.3f"));
41-
}
34+
System.out.println(result.confusionMatrix());
4235
}
4336

4437
private Sequential getModel() {

0 commit comments

Comments
 (0)