Skip to content

Commit 20ce84b

Browse files
committed
feat: fixed smart trainer
1 parent bd9ec3d commit 20ce84b

File tree

2 files changed

+52
-28
lines changed

2 files changed

+52
-28
lines changed

brain4j-core/src/main/java/org/brain4j/core/training/advanced/SmartTrainer.java

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,58 +33,63 @@ public void addListener(TrainListener listener) {
3333
public void abort() {
3434
this.running = false;
3535
}
36-
37-
public void start(Model model, ListDataSource dataSource, double lossThreshold) {
38-
this.start = System.nanoTime();
39-
this.running = true;
40-
this.epoches = 0;
41-
42-
this.listeners.forEach(listener -> listener.register(this, model));
43-
44-
while (running && evaluation.loss() > lossThreshold) {
45-
iterate(model, dataSource, Integer.MAX_VALUE);
46-
}
47-
48-
this.running = false;
49-
this.end = System.nanoTime();
50-
}
51-
36+
5237
public void step(Model model, ListDataSource dataSource, int totalEpoches) {
5338
long start = System.nanoTime();
5439
this.listeners.forEach(listener -> listener.onEpochStarted(epoches, totalEpoches, start));
55-
40+
5641
model.fit(dataSource);
57-
42+
5843
long took = System.nanoTime() - start;
5944
this.listeners.forEach(listener -> listener.onEpochCompleted(epoches, totalEpoches, took));
6045
}
61-
62-
public void startFor(Model model, ListDataSource dataSource, int epochesAmount) {
46+
47+
private void setupPreTraining(Model model, ListDataSource trainSource, ListDataSource evalSource) {
48+
if (trainSource == null || evalSource == null) {
49+
throw new IllegalArgumentException("Training source and evaluation source are required and cannot be null!");
50+
}
51+
6352
this.start = System.nanoTime();
6453
this.running = true;
6554
this.epoches = 0;
66-
55+
56+
this.evaluation = model.evaluate(evalSource);
6757
this.listeners.forEach(listener -> listener.register(this, model));
58+
}
59+
60+
public void start(Model model, ListDataSource trainSource, ListDataSource evalSource, double lossThreshold) {
61+
setupPreTraining(model, trainSource, evalSource);
62+
63+
while (running && evaluation.loss() > lossThreshold) {
64+
iterate(model, trainSource, evalSource, Integer.MAX_VALUE);
65+
}
6866

67+
this.running = false;
68+
this.end = System.nanoTime();
69+
}
70+
71+
public void startFor(Model model, ListDataSource trainSource, ListDataSource evalSource, int epochesAmount) {
72+
setupPreTraining(model, trainSource, evalSource);
73+
6974
for (int i = 0; i < epochesAmount && running; i++) {
70-
iterate(model, dataSource, epochesAmount);
75+
iterate(model, trainSource, evalSource, epochesAmount);
7176
}
7277

7378
this.running = false;
7479
this.end = System.nanoTime();
7580
}
7681

77-
private void iterate(Model model, ListDataSource dataSource, int totalEpoches) {
78-
step(model, dataSource, totalEpoches);
82+
private void iterate(Model model, ListDataSource trainSource, ListDataSource evalSource, int totalEpoches) {
83+
step(model, trainSource, totalEpoches);
7984

8085
this.epoches++;
8186

8287
if (epoches % evaluateEvery == 0) {
8388
long start = System.nanoTime();
84-
EvaluationResult result = model.evaluate(dataSource);
89+
this.evaluation = model.evaluate(evalSource);
8590
long took = System.nanoTime() - start;
8691

87-
this.listeners.forEach(listener -> listener.onEvaluated(dataSource, result, epoches, took));
92+
this.listeners.forEach(listener -> listener.onEvaluated(evalSource, evaluation, epoches, took));
8893
}
8994
}
9095

brain4j-core/src/main/java/org/brain4j/core/training/advanced/TrainListener.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,29 @@ public void onEpochCompleted(int epoch, int totalEpoches, long took) {
2727
*/
2828
public void onEpochStarted(int epoch, int totalEpoches, long start) {
2929
}
30-
30+
31+
/**
32+
* Called after an evaluation has completed.
33+
* @param dataSource the data source where the evaluation has been done
34+
* @param evaluation the evaluation result
35+
* @param epoch the current epoch
36+
* @param took how many nanoseconds it took to evaluate
37+
*/
3138
public void onEvaluated(ListDataSource dataSource, EvaluationResult evaluation, int epoch, long took) {
3239
}
33-
40+
41+
/**
42+
* Called when the loss gets increased
43+
* @param loss the current loss
44+
* @param previousLoss the previous loss
45+
*/
3446
public void onLossIncreased(double loss, double previousLoss) {
3547
}
48+
49+
/**
50+
* Aborts the training session, delegates to {@link SmartTrainer#abort()}
51+
*/
52+
public void abort() {
53+
trainer.abort();
54+
}
3655
}

0 commit comments

Comments
 (0)