@@ -42,7 +42,7 @@ public class ScatterView extends View {
42
42
43
43
private final int nPointsPerAxis = 100 ;
44
44
private INDArray xyGrid ; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
45
- private MultiLayerNetwork model ;
45
+ INDArray modelOut = null ;
46
46
47
47
public ScatterView (Context context , @ Nullable AttributeSet attrs ) {
48
48
super (context , attrs );
@@ -74,30 +74,25 @@ public void onDraw(Canvas canvas) {
74
74
int h = this .getHeight ();
75
75
int w = this .getWidth ();
76
76
77
- if (null == data ) {
78
- canvas .drawColor (Color .rgb (32 , 32 , 32 ));
79
- canvas .drawCircle (800 , 500 , 200 , redPaint );
80
- canvas .drawCircle (325 , 900 , 300 , greenPaint );
81
- } else {
82
-
83
- //draw the nn predictions:
77
+ //draw the nn predictions:
78
+ if ((modelOut != null ) && (null != xyGrid )){
84
79
int halfRectHeight = h / nPointsPerAxis ;
85
80
int halfRectWidth = w / nPointsPerAxis ;
86
- INDArray modelOut = model .output (xyGrid );
87
-
88
81
int nRows = xyGrid .rows ();
89
82
90
83
for (int i = 0 ; i < nRows ; i ++){
91
- int x = (int )(xyGrid .getFloat (i , 0 ) * w );
92
- int y = (int ) (xyGrid .getFloat (i , 1 ) * h );
93
- float z = modelOut .getFloat (i , 0 );
94
- Paint p = (z >= 0.5f ) ? lightGreenPaint : lightRedPaint ;
95
- canvas .drawRect (x -halfRectWidth , y -halfRectHeight , x +halfRectWidth , y +halfRectHeight , p );
96
- // }
84
+ int x = (int )(xyGrid .getFloat (i , 0 ) * w );
85
+ int y = (int ) (xyGrid .getFloat (i , 1 ) * h );
86
+ float z = modelOut .getFloat (i , 0 );
87
+ Paint p = (z >= 0.5f ) ? lightGreenPaint : lightRedPaint ;
88
+ canvas .drawRect (x -halfRectWidth , y -halfRectHeight , x +halfRectWidth , y +halfRectHeight , p );
89
+ // }
97
90
}
91
+ }
98
92
93
+ //draw the data set if we have one.
94
+ if (null != data ) {
99
95
100
- //draw the data set
101
96
for (float [] datum : data ) {
102
97
int x = (int ) (datum [1 ] * w );
103
98
int y = (int ) (datum [2 ] * h );
@@ -173,11 +168,11 @@ private void normalizeColumn(int c, float[][] tmpData){
173
168
174
169
private void BuildNN (){
175
170
int seed = 123 ;
176
- double learningRate = 0.01 ;
171
+ double learningRate = 0.005 ;
177
172
int numInputs = 2 ;
178
173
int numOutputs = 2 ;
179
174
int numHiddenNodes = 20 ;
180
- int nEpochs = 200 ;
175
+ int nEpochs = 2000 ;
181
176
182
177
MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
183
178
.seed (seed )
@@ -192,12 +187,18 @@ private void BuildNN(){
192
187
.nIn (numHiddenNodes ).nOut (numOutputs ).build ())
193
188
.build ();
194
189
195
- model = new MultiLayerNetwork (conf );
190
+ MultiLayerNetwork model = new MultiLayerNetwork (conf );
196
191
model .init ();
197
192
model .setListeners (new ScoreIterationListener (10 ));
198
193
199
194
for (int i = 0 ; i <nEpochs ; i ++){
200
195
model .fit (ds );
196
+ INDArray tmp = model .output (xyGrid );
197
+
198
+ this .post (() -> {
199
+ this .modelOut = tmp ; // update from within the UI thread.
200
+ this .invalidate (); // have the UI thread redraw at its own convenience.
201
+ });
201
202
}
202
203
203
204
Evaluation eval = new Evaluation (numOutputs );
0 commit comments