Skip to content

Commit b16d243

Browse files
RobAltenatreo
authored andcommitted
Remove AsyncTask
Signed-off-by: Paul Dubs <[email protected]>
1 parent ccd5190 commit b16d243

File tree

4 files changed

+222
-161
lines changed

4 files changed

+222
-161
lines changed

android-examples/app/src/main/java/com/example/androidDl4jClassifier/MainActivity.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public boolean onOptionsItemSelected(MenuItem item) {
2323
switch (item.getItemId()) {
2424
case R.id.action_linear:
2525
// User chose linear dataset.
26+
ScatterView view = findViewById(R.id.id_scatterview);
2627
return true;
2728

2829
case R.id.action_moon:

android-examples/app/src/main/java/com/example/androidDl4jClassifier/ScatterView.java

Lines changed: 20 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -4,49 +4,27 @@
44
import android.graphics.Canvas;
55
import android.graphics.Color;
66
import android.graphics.Paint;
7-
import android.os.AsyncTask;
87
import android.util.AttributeSet;
98
import android.view.View;
109

1110
import androidx.annotation.Nullable;
1211

13-
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
14-
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
15-
import org.deeplearning4j.nn.conf.layers.DenseLayer;
16-
import org.deeplearning4j.nn.conf.layers.OutputLayer;
17-
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
18-
import org.deeplearning4j.nn.weights.WeightInit;
19-
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
20-
import org.nd4j.evaluation.classification.Evaluation;
21-
import org.nd4j.linalg.activations.Activation;
22-
import org.nd4j.linalg.api.buffer.DataType;
2312
import org.nd4j.linalg.api.ndarray.INDArray;
24-
import org.nd4j.linalg.dataset.DataSet;
25-
import org.nd4j.linalg.factory.Nd4j;
26-
import org.nd4j.linalg.learning.config.Nesterovs;
27-
import org.nd4j.linalg.lossfunctions.LossFunctions;
2813

29-
import java.io.BufferedReader;
30-
import java.io.IOException;
31-
import java.io.InputStreamReader;
32-
import java.util.ArrayList;
33-
34-
public class ScatterView extends View {
14+
public class ScatterView extends View implements OnTrainingUpdateEventListener{
3515

3616
private final Paint redPaint;
3717
private final Paint greenPaint;
3818
private final Paint lightGreenPaint;
3919
private final Paint lightRedPaint;
40-
private float[][] data;
41-
private DataSet ds;
4220

43-
private final int nPointsPerAxis = 100;
44-
private INDArray xyGrid; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
45-
private INDArray modelOut = null;
21+
22+
private INDArray modelOut = null; // nn output for the grid.
23+
24+
private TrainingTask task;
4625

4726
public ScatterView(Context context, @Nullable AttributeSet attrs) {
4827
super(context, attrs);
49-
data = null;
5028
redPaint = new Paint();
5129
redPaint.setColor(Color.RED);
5230
greenPaint = new Paint();
@@ -57,16 +35,9 @@ public ScatterView(Context context, @Nullable AttributeSet attrs) {
5735
lightRedPaint = new Paint();
5836
lightRedPaint.setColor(Color.rgb(255, 153, 152));
5937

60-
AsyncTask.execute(() -> {
61-
try {
62-
calcGrid();
63-
ReadCSV();
64-
BuildNN();
65-
66-
} catch (IOException e) {
67-
e.printStackTrace();
68-
}
69-
});
38+
task = new TrainingTask();
39+
task.setListener(this);
40+
task.executeTask();
7041
}
7142

7243
@Override
@@ -75,9 +46,10 @@ public void onDraw(Canvas canvas) {
7546
int w = this.getWidth();
7647

7748
//draw the nn predictions:
78-
if ((modelOut != null) && (null != xyGrid )){
79-
int halfRectHeight = h / nPointsPerAxis;
80-
int halfRectWidth = w / nPointsPerAxis;
49+
if (modelOut != null) {
50+
int halfRectHeight = h / task.getnPointsPerAxis();
51+
int halfRectWidth = w / task.getnPointsPerAxis();
52+
INDArray xyGrid = task.getXyGrid();
8153
int nRows = xyGrid.rows();
8254

8355
for (int i = 0; i< nRows; i++){
@@ -91,9 +63,9 @@ public void onDraw(Canvas canvas) {
9163
}
9264

9365
//draw the data set if we have one.
94-
if (null != data) {
66+
if (null != task.getData()) {
9567

96-
for (float[] datum : data) {
68+
for (float[] datum : task.getData()) {
9769
int x = (int) (datum[1] * w);
9870
int y = (int) (datum[2] * h);
9971
Paint p = (datum[0] == 0.0f) ? redPaint : greenPaint;
@@ -102,124 +74,11 @@ public void onDraw(Canvas canvas) {
10274
}
10375
}
10476

105-
/**
106-
* this is not the regular way to read a csv file into a data set with DL4j.
107-
* In this example we have put the data in the assets folder so that the demo works offline.
108-
*/
109-
private void ReadCSV() throws IOException {
110-
InputStreamReader is = new InputStreamReader(MainActivity.getInstance().getApplicationContext().getAssets()
111-
.open("linear_data_train.csv"));
112-
113-
BufferedReader reader = new BufferedReader(is);
114-
ArrayList<String> rawSVC = new ArrayList<>();
115-
String line;
116-
while ((line = reader.readLine()) != null) {
117-
rawSVC.add(line);
118-
}
119-
120-
float[][] tmpData = new float[rawSVC.size()][3];
121-
122-
int index = 0;
123-
for(String l : rawSVC){
124-
String[] values = l.split(",");
125-
for(int col = 0; col< 3L; col++){
126-
tmpData[index][col] = Float.parseFloat(values[col]);
127-
}
128-
129-
index++;
130-
}
131-
132-
normalizeColumn(1, tmpData);
133-
normalizeColumn(2, tmpData);
134-
135-
this.data = tmpData;
136-
INDArray arrData = Nd4j.createFromArray(tmpData);
137-
INDArray arrFeatures = arrData.getColumns(1, 2);
138-
INDArray c1 = arrData.getColumns(0);
139-
INDArray c2 = c1.mul(-1).addi(1.0);
140-
INDArray labels = Nd4j.hstack(c1, c2);
141-
ds = new DataSet(arrFeatures, labels);
142-
}
143-
144-
/**
145-
* Normalize the data in a given column. Normally one would use datavec.
146-
* @param c column to normalise.
147-
* @param tmpData java float array.
148-
*/
149-
private void normalizeColumn(int c, float[][] tmpData){
150-
int numPoints = tmpData.length;
151-
float min= tmpData[0][c];
152-
float max= tmpData[0][c];
153-
for (float[] tmpDatum : tmpData) {
154-
float x = tmpDatum[c];
155-
if (x < min) {
156-
min = x;
157-
}
158-
if (x > max) {
159-
max = x;
160-
}
161-
}
162-
163-
for (int i=0; i<numPoints; i++){
164-
float x = tmpData[i][c];
165-
tmpData[i][c] = (x - min) / (max - min);
166-
}
167-
}
168-
169-
private void BuildNN(){
170-
int seed = 123;
171-
double learningRate = 0.005;
172-
int numInputs = 2;
173-
int numOutputs = 2;
174-
int numHiddenNodes = 20;
175-
int nEpochs = 2000;
176-
177-
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
178-
.seed(seed)
179-
.weightInit(WeightInit.XAVIER)
180-
.updater(new Nesterovs(learningRate, 0.9))
181-
.list()
182-
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
183-
.activation(Activation.RELU)
184-
.build())
185-
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
186-
.activation(Activation.SOFTMAX)
187-
.nIn(numHiddenNodes).nOut(numOutputs).build())
188-
.build();
189-
190-
MultiLayerNetwork model = new MultiLayerNetwork(conf);
191-
model.init();
192-
model.setListeners(new ScoreIterationListener(10));
193-
194-
for(int i = 0; i<nEpochs; i++){
195-
model.fit(ds);
196-
INDArray tmp = model.output(xyGrid);
197-
198-
this.post(() -> {
199-
this.modelOut = tmp; // update from within the UI thread.
200-
this.invalidate(); // have the UI thread redraw at its own convenience.
201-
});
202-
}
203-
204-
Evaluation eval = new Evaluation(numOutputs);
205-
INDArray features = ds.getFeatures();
206-
INDArray labels = ds.getLabels();
207-
INDArray predicted = model.output(features,false);
208-
eval.eval(labels, predicted);
209-
System.out.println(eval.stats());
210-
211-
this.invalidate();
212-
}
213-
/**
214-
* The x,y grid to calculate the NN output. Only needs to be calculated once.
215-
*/
216-
private void calcGrid(){
217-
// x coordinates of the pixels for the NN.
218-
INDArray xPixels = Nd4j.linspace(0, 1.0, nPointsPerAxis, DataType.DOUBLE);
219-
// y coordinates of the pixels for the NN.
220-
INDArray yPixels = Nd4j.linspace(0, 1.0, nPointsPerAxis, DataType.DOUBLE);
221-
//create the mesh:
222-
INDArray [] mesh = Nd4j.meshgrid(xPixels, yPixels);
223-
xyGrid = Nd4j.vstack(mesh[0].ravel(), mesh[1].ravel()).transpose();
77+
@Override
78+
public void OnTrainingUpdate(INDArray modelOut) {
79+
this.post(() -> {
80+
this.modelOut = modelOut; // update from within the UI thread.
81+
this.invalidate(); // have the UI thread redraw at its own convenience.
82+
});
22483
}
22584
}

0 commit comments

Comments
 (0)