@@ -33,58 +33,63 @@ public void addListener(TrainListener listener) {
3333 public void abort () {
3434 this .running = false ;
3535 }
36-
37- public void start (Model model , ListDataSource dataSource , double lossThreshold ) {
38- this .start = System .nanoTime ();
39- this .running = true ;
40- this .epoches = 0 ;
41-
42- this .listeners .forEach (listener -> listener .register (this , model ));
43-
44- while (running && evaluation .loss () > lossThreshold ) {
45- iterate (model , dataSource , Integer .MAX_VALUE );
46- }
47-
48- this .running = false ;
49- this .end = System .nanoTime ();
50- }
51-
36+
5237 public void step (Model model , ListDataSource dataSource , int totalEpoches ) {
5338 long start = System .nanoTime ();
5439 this .listeners .forEach (listener -> listener .onEpochStarted (epoches , totalEpoches , start ));
55-
40+
5641 model .fit (dataSource );
57-
42+
5843 long took = System .nanoTime () - start ;
5944 this .listeners .forEach (listener -> listener .onEpochCompleted (epoches , totalEpoches , took ));
6045 }
61-
62- public void startFor (Model model , ListDataSource dataSource , int epochesAmount ) {
46+
47+ private void setupPreTraining (Model model , ListDataSource trainSource , ListDataSource evalSource ) {
48+ if (trainSource == null || evalSource == null ) {
49+ throw new IllegalArgumentException ("Training source and evaluation source are required and cannot be null!" );
50+ }
51+
6352 this .start = System .nanoTime ();
6453 this .running = true ;
6554 this .epoches = 0 ;
66-
55+
56+ this .evaluation = model .evaluate (evalSource );
6757 this .listeners .forEach (listener -> listener .register (this , model ));
58+ }
59+
60+ public void start (Model model , ListDataSource trainSource , ListDataSource evalSource , double lossThreshold ) {
61+ setupPreTraining (model , trainSource , evalSource );
62+
63+ while (running && evaluation .loss () > lossThreshold ) {
64+ iterate (model , trainSource , evalSource , Integer .MAX_VALUE );
65+ }
6866
67+ this .running = false ;
68+ this .end = System .nanoTime ();
69+ }
70+
71+ public void startFor (Model model , ListDataSource trainSource , ListDataSource evalSource , int epochesAmount ) {
72+ setupPreTraining (model , trainSource , evalSource );
73+
6974 for (int i = 0 ; i < epochesAmount && running ; i ++) {
70- iterate (model , dataSource , epochesAmount );
75+ iterate (model , trainSource , evalSource , epochesAmount );
7176 }
7277
7378 this .running = false ;
7479 this .end = System .nanoTime ();
7580 }
7681
77- private void iterate (Model model , ListDataSource dataSource , int totalEpoches ) {
78- step (model , dataSource , totalEpoches );
82+ private void iterate (Model model , ListDataSource trainSource , ListDataSource evalSource , int totalEpoches ) {
83+ step (model , trainSource , totalEpoches );
7984
8085 this .epoches ++;
8186
8287 if (epoches % evaluateEvery == 0 ) {
8388 long start = System .nanoTime ();
84- EvaluationResult result = model .evaluate (dataSource );
89+ this . evaluation = model .evaluate (evalSource );
8590 long took = System .nanoTime () - start ;
8691
87- this .listeners .forEach (listener -> listener .onEvaluated (dataSource , result , epoches , took ));
92+ this .listeners .forEach (listener -> listener .onEvaluated (evalSource , evaluation , epoches , took ));
8893 }
8994 }
9095
0 commit comments