4
4
import android .graphics .Canvas ;
5
5
import android .graphics .Color ;
6
6
import android .graphics .Paint ;
7
- import android .os .AsyncTask ;
8
7
import android .util .AttributeSet ;
9
8
import android .view .View ;
10
9
11
10
import androidx .annotation .Nullable ;
12
11
13
- import org .deeplearning4j .nn .conf .MultiLayerConfiguration ;
14
- import org .deeplearning4j .nn .conf .NeuralNetConfiguration ;
15
- import org .deeplearning4j .nn .conf .layers .DenseLayer ;
16
- import org .deeplearning4j .nn .conf .layers .OutputLayer ;
17
- import org .deeplearning4j .nn .multilayer .MultiLayerNetwork ;
18
- import org .deeplearning4j .nn .weights .WeightInit ;
19
- import org .deeplearning4j .optimize .listeners .ScoreIterationListener ;
20
- import org .nd4j .evaluation .classification .Evaluation ;
21
- import org .nd4j .linalg .activations .Activation ;
22
- import org .nd4j .linalg .api .buffer .DataType ;
23
12
import org .nd4j .linalg .api .ndarray .INDArray ;
24
- import org .nd4j .linalg .dataset .DataSet ;
25
- import org .nd4j .linalg .factory .Nd4j ;
26
- import org .nd4j .linalg .learning .config .Nesterovs ;
27
- import org .nd4j .linalg .lossfunctions .LossFunctions ;
28
13
29
- import java .io .BufferedReader ;
30
- import java .io .IOException ;
31
- import java .io .InputStreamReader ;
32
- import java .util .ArrayList ;
33
-
34
- public class ScatterView extends View {
14
+ public class ScatterView extends View implements OnTrainingUpdateEventListener {
35
15
36
16
private final Paint redPaint ;
37
17
private final Paint greenPaint ;
38
18
private final Paint lightGreenPaint ;
39
19
private final Paint lightRedPaint ;
40
- private float [][] data ;
41
- private DataSet ds ;
42
20
43
- private final int nPointsPerAxis = 100 ;
44
- private INDArray xyGrid ; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
45
- private INDArray modelOut = null ;
21
+
22
+ private INDArray modelOut = null ; // nn output for the grid.
23
+
24
+ private TrainingTask task ;
46
25
47
26
public ScatterView (Context context , @ Nullable AttributeSet attrs ) {
48
27
super (context , attrs );
49
- data = null ;
50
28
redPaint = new Paint ();
51
29
redPaint .setColor (Color .RED );
52
30
greenPaint = new Paint ();
@@ -57,16 +35,9 @@ public ScatterView(Context context, @Nullable AttributeSet attrs) {
57
35
lightRedPaint = new Paint ();
58
36
lightRedPaint .setColor (Color .rgb (255 , 153 , 152 ));
59
37
60
- AsyncTask .execute (() -> {
61
- try {
62
- calcGrid ();
63
- ReadCSV ();
64
- BuildNN ();
65
-
66
- } catch (IOException e ) {
67
- e .printStackTrace ();
68
- }
69
- });
38
+ task = new TrainingTask ();
39
+ task .setListener (this );
40
+ task .executeTask ();
70
41
}
71
42
72
43
@ Override
@@ -75,9 +46,10 @@ public void onDraw(Canvas canvas) {
75
46
int w = this .getWidth ();
76
47
77
48
//draw the nn predictions:
78
- if ((modelOut != null ) && (null != xyGrid )){
79
- int halfRectHeight = h / nPointsPerAxis ;
80
- int halfRectWidth = w / nPointsPerAxis ;
49
+ if (modelOut != null ) {
50
+ int halfRectHeight = h / task .getnPointsPerAxis ();
51
+ int halfRectWidth = w / task .getnPointsPerAxis ();
52
+ INDArray xyGrid = task .getXyGrid ();
81
53
int nRows = xyGrid .rows ();
82
54
83
55
for (int i = 0 ; i < nRows ; i ++){
@@ -91,9 +63,9 @@ public void onDraw(Canvas canvas) {
91
63
}
92
64
93
65
//draw the data set if we have one.
94
- if (null != data ) {
66
+ if (null != task . getData () ) {
95
67
96
- for (float [] datum : data ) {
68
+ for (float [] datum : task . getData () ) {
97
69
int x = (int ) (datum [1 ] * w );
98
70
int y = (int ) (datum [2 ] * h );
99
71
Paint p = (datum [0 ] == 0.0f ) ? redPaint : greenPaint ;
@@ -102,124 +74,11 @@ public void onDraw(Canvas canvas) {
102
74
}
103
75
}
104
76
105
- /**
106
- * this is not the regular way to read a csv file into a data set with DL4j.
107
- * In this example we have put the data in the assets folder so that the demo works offline.
108
- */
109
- private void ReadCSV () throws IOException {
110
- InputStreamReader is = new InputStreamReader (MainActivity .getInstance ().getApplicationContext ().getAssets ()
111
- .open ("linear_data_train.csv" ));
112
-
113
- BufferedReader reader = new BufferedReader (is );
114
- ArrayList <String > rawSVC = new ArrayList <>();
115
- String line ;
116
- while ((line = reader .readLine ()) != null ) {
117
- rawSVC .add (line );
118
- }
119
-
120
- float [][] tmpData = new float [rawSVC .size ()][3 ];
121
-
122
- int index = 0 ;
123
- for (String l : rawSVC ){
124
- String [] values = l .split ("," );
125
- for (int col = 0 ; col < 3L ; col ++){
126
- tmpData [index ][col ] = Float .parseFloat (values [col ]);
127
- }
128
-
129
- index ++;
130
- }
131
-
132
- normalizeColumn (1 , tmpData );
133
- normalizeColumn (2 , tmpData );
134
-
135
- this .data = tmpData ;
136
- INDArray arrData = Nd4j .createFromArray (tmpData );
137
- INDArray arrFeatures = arrData .getColumns (1 , 2 );
138
- INDArray c1 = arrData .getColumns (0 );
139
- INDArray c2 = c1 .mul (-1 ).addi (1.0 );
140
- INDArray labels = Nd4j .hstack (c1 , c2 );
141
- ds = new DataSet (arrFeatures , labels );
142
- }
143
-
144
- /**
145
- * Normalize the data in a given column. Normally one would use datavec.
146
- * @param c column to normalise.
147
- * @param tmpData java float array.
148
- */
149
- private void normalizeColumn (int c , float [][] tmpData ){
150
- int numPoints = tmpData .length ;
151
- float min = tmpData [0 ][c ];
152
- float max = tmpData [0 ][c ];
153
- for (float [] tmpDatum : tmpData ) {
154
- float x = tmpDatum [c ];
155
- if (x < min ) {
156
- min = x ;
157
- }
158
- if (x > max ) {
159
- max = x ;
160
- }
161
- }
162
-
163
- for (int i =0 ; i <numPoints ; i ++){
164
- float x = tmpData [i ][c ];
165
- tmpData [i ][c ] = (x - min ) / (max - min );
166
- }
167
- }
168
-
169
- private void BuildNN (){
170
- int seed = 123 ;
171
- double learningRate = 0.005 ;
172
- int numInputs = 2 ;
173
- int numOutputs = 2 ;
174
- int numHiddenNodes = 20 ;
175
- int nEpochs = 2000 ;
176
-
177
- MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
178
- .seed (seed )
179
- .weightInit (WeightInit .XAVIER )
180
- .updater (new Nesterovs (learningRate , 0.9 ))
181
- .list ()
182
- .layer (new DenseLayer .Builder ().nIn (numInputs ).nOut (numHiddenNodes )
183
- .activation (Activation .RELU )
184
- .build ())
185
- .layer (new OutputLayer .Builder (LossFunctions .LossFunction .NEGATIVELOGLIKELIHOOD )
186
- .activation (Activation .SOFTMAX )
187
- .nIn (numHiddenNodes ).nOut (numOutputs ).build ())
188
- .build ();
189
-
190
- MultiLayerNetwork model = new MultiLayerNetwork (conf );
191
- model .init ();
192
- model .setListeners (new ScoreIterationListener (10 ));
193
-
194
- for (int i = 0 ; i <nEpochs ; i ++){
195
- model .fit (ds );
196
- INDArray tmp = model .output (xyGrid );
197
-
198
- this .post (() -> {
199
- this .modelOut = tmp ; // update from within the UI thread.
200
- this .invalidate (); // have the UI thread redraw at its own convenience.
201
- });
202
- }
203
-
204
- Evaluation eval = new Evaluation (numOutputs );
205
- INDArray features = ds .getFeatures ();
206
- INDArray labels = ds .getLabels ();
207
- INDArray predicted = model .output (features ,false );
208
- eval .eval (labels , predicted );
209
- System .out .println (eval .stats ());
210
-
211
- this .invalidate ();
212
- }
213
- /**
214
- * The x,y grid to calculate the NN output. Only needs to be calculated once.
215
- */
216
- private void calcGrid (){
217
- // x coordinates of the pixels for the NN.
218
- INDArray xPixels = Nd4j .linspace (0 , 1.0 , nPointsPerAxis , DataType .DOUBLE );
219
- // y coordinates of the pixels for the NN.
220
- INDArray yPixels = Nd4j .linspace (0 , 1.0 , nPointsPerAxis , DataType .DOUBLE );
221
- //create the mesh:
222
- INDArray [] mesh = Nd4j .meshgrid (xPixels , yPixels );
223
- xyGrid = Nd4j .vstack (mesh [0 ].ravel (), mesh [1 ].ravel ()).transpose ();
77
+ @ Override
78
+ public void OnTrainingUpdate (INDArray modelOut ) {
79
+ this .post (() -> {
80
+ this .modelOut = modelOut ; // update from within the UI thread.
81
+ this .invalidate (); // have the UI thread redraw at its own convenience.
82
+ });
224
83
}
225
84
}
0 commit comments