77import net .echo .brain4j .layer .impl .convolution .InputLayer ;
88import net .echo .brain4j .loss .LossFunctions ;
99import net .echo .brain4j .model .impl .Sequential ;
10- import net .echo .brain4j .model .initialization .WeightInit ;
1110import net .echo .brain4j .training .data .DataRow ;
1211import net .echo .brain4j .training .optimizers .impl .Adam ;
13- import net .echo .brain4j .training .techniques .SmartTrainer ;
14- import net .echo .brain4j .training .techniques .TrainListener ;
15- import net .echo .brain4j .training .updater .impl .StochasticUpdater ;
1612import net .echo .brain4j .utils .DataSet ;
1713import net .echo .brain4j .utils .MLUtils ;
1814import net .echo .brain4j .utils .Vector ;
19- import org .apache .commons .io .FileUtils ;
15+ import org .apache .commons .csv .CSVFormat ;
16+ import org .apache .commons .csv .CSVParser ;
2017
21- import java .io .File ;
18+ import java .io .FileReader ;
2219import java .io .IOException ;
23- import java .util .Arrays ;
2420import java .util .List ;
2521
2622public class ConvExample {
@@ -36,15 +32,7 @@ private void start() throws IOException {
3632
3733 System .out .println (model .getStats ());
3834
39- SmartTrainer trainer = new SmartTrainer (1 , 1 );
40- trainer .addListener (new TrainListener <DataRow >() {
41- @ Override
42- public void onEvaluated (DataSet <DataRow > dataSet , int epoch , double loss , long took ) {
43- System .out .println ("Epoch #" + epoch + " Loss: " + loss );
44- }
45- });
46- trainer .startFor (model , dataSet , 100 , 0.01 );
47-
35+ model .fit (dataSet , 10 );
4836 model .save ("mnist-conv.json" );
4937
5038 int incorrect = 0 ;
@@ -91,20 +79,22 @@ private Sequential getModel() {
9179
9280 private DataSet <DataRow > getDataSet () throws IOException {
9381 DataSet <DataRow > dataSet = new DataSet <>();
94- List <String > lines = FileUtils .readLines (new File ("dataset.csv" ), "UTF-8" );
9582
96- for (int j = 0 ; j < 150 * 2 ; j ++) {
97- String line = lines .get (j );
98- String [] parts = line .split ("," );
99- double [] inputs = Arrays .stream (parts , 1 , parts .length ).mapToDouble (x -> Double .parseDouble (x ) / 255 ).toArray ();
83+ FileReader reader = new FileReader ("dataset.csv" );
84+ CSVParser parser = new CSVParser (reader , CSVFormat .EXCEL );
10085
101- Vector output = new Vector (10 );
86+ parser .forEach (record -> {
87+ List <String > columns = record .toList ();
10288
103- int value = Integer . parseInt ( parts [ 0 ] );
104- output . set ( value , 1 );
89+ String label = columns . getFirst ( );
90+ List < String > pixels = columns . subList ( 1 , columns . size () );
10591
106- dataSet .getData ().add (new DataRow (Vector .of (inputs ), output ));
107- }
92+ Vector output = new Vector (10 );
93+ output .set (Integer .parseInt (label ), 1 );
94+
95+ Vector input = Vector .parse (pixels ).divide (255 );
96+ dataSet .add (new DataRow (input , output ));
97+ });
10898
10999 return dataSet ;
110100 }
0 commit comments