Skip to content

Commit 91e1e76

Browse files
RobAltenatreo
authored andcommitted
adds animation during learning.
Signed-off-by: Paul Dubs <[email protected]>
1 parent 9d66e69 commit 91e1e76

File tree

1 file changed

+21
-20
lines changed
  • android-examples/app/src/main/java/com/example/androidDl4jClassifier

1 file changed

+21
-20
lines changed

android-examples/app/src/main/java/com/example/androidDl4jClassifier/ScatterView.java

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public class ScatterView extends View {
4242

4343
private final int nPointsPerAxis = 100;
4444
private INDArray xyGrid; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
45-
private MultiLayerNetwork model;
45+
INDArray modelOut = null;
4646

4747
public ScatterView(Context context, @Nullable AttributeSet attrs) {
4848
super(context, attrs);
@@ -74,30 +74,25 @@ public void onDraw(Canvas canvas) {
7474
int h = this.getHeight();
7575
int w = this.getWidth();
7676

77-
if (null == data) {
78-
canvas.drawColor(Color.rgb(32, 32, 32));
79-
canvas.drawCircle(800, 500, 200, redPaint);
80-
canvas.drawCircle(325, 900, 300, greenPaint);
81-
} else {
82-
83-
//draw the nn predictions:
77+
//draw the nn predictions:
78+
if ((modelOut != null) && (null != xyGrid )){
8479
int halfRectHeight = h / nPointsPerAxis;
8580
int halfRectWidth = w / nPointsPerAxis;
86-
INDArray modelOut = model.output(xyGrid);
87-
8881
int nRows = xyGrid.rows();
8982

9083
for (int i = 0; i< nRows; i++){
91-
int x = (int)(xyGrid.getFloat(i, 0) * w);
92-
int y = (int) (xyGrid.getFloat(i, 1) * h);
93-
float z = modelOut.getFloat(i, 0);
94-
Paint p = (z >= 0.5f) ? lightGreenPaint : lightRedPaint;
95-
canvas.drawRect(x-halfRectWidth, y-halfRectHeight, x+halfRectWidth, y+halfRectHeight, p);
96-
// }
84+
int x = (int)(xyGrid.getFloat(i, 0) * w);
85+
int y = (int) (xyGrid.getFloat(i, 1) * h);
86+
float z = modelOut.getFloat(i, 0);
87+
Paint p = (z >= 0.5f) ? lightGreenPaint : lightRedPaint;
88+
canvas.drawRect(x-halfRectWidth, y-halfRectHeight, x+halfRectWidth, y+halfRectHeight, p);
89+
// }
9790
}
91+
}
9892

93+
//draw the data set if we have one.
94+
if (null != data) {
9995

100-
//draw the data set
10196
for (float[] datum : data) {
10297
int x = (int) (datum[1] * w);
10398
int y = (int) (datum[2] * h);
@@ -173,11 +168,11 @@ private void normalizeColumn(int c, float[][] tmpData){
173168

174169
private void BuildNN(){
175170
int seed = 123;
176-
double learningRate = 0.01;
171+
double learningRate = 0.005;
177172
int numInputs = 2;
178173
int numOutputs = 2;
179174
int numHiddenNodes = 20;
180-
int nEpochs = 200;
175+
int nEpochs = 2000;
181176

182177
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
183178
.seed(seed)
@@ -192,12 +187,18 @@ private void BuildNN(){
192187
.nIn(numHiddenNodes).nOut(numOutputs).build())
193188
.build();
194189

195-
model = new MultiLayerNetwork(conf);
190+
MultiLayerNetwork model = new MultiLayerNetwork(conf);
196191
model.init();
197192
model.setListeners(new ScoreIterationListener(10));
198193

199194
for(int i = 0; i<nEpochs; i++){
200195
model.fit(ds);
196+
INDArray tmp = model.output(xyGrid);
197+
198+
this.post(() -> {
199+
this.modelOut = tmp; // update from within the UI thread.
200+
this.invalidate(); // have the UI thread redraw at its own convenience.
201+
});
201202
}
202203

203204
Evaluation eval = new Evaluation(numOutputs);

0 commit comments

Comments
 (0)