Skip to content

Commit 9d66e69

Browse files
RobAltenatreo
authored andcommitted
First working version.
Signed-off-by: Paul Dubs <[email protected]>
1 parent 13967d9 commit 9d66e69

File tree

19 files changed

+277
-364
lines changed

19 files changed

+277
-364
lines changed

android-examples/.idea/misc.xml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

android-examples/app/build.gradle

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ android {
55
buildToolsVersion "29.0.3"
66

77
defaultConfig {
8-
applicationId "com.example.androidimageexperiment"
8+
applicationId "com.example.androidDl4jClassifier"
99
minSdkVersion 29
1010
targetSdkVersion 29
1111
versionCode 1
@@ -20,7 +20,10 @@ android {
2020
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
2121
}
2222
}
23-
23+
compileOptions {
24+
sourceCompatibility JavaVersion.VERSION_1_8
25+
targetCompatibility JavaVersion.VERSION_1_8
26+
}
2427
}
2528

2629
dependencies {
@@ -31,4 +34,35 @@ dependencies {
3134
testImplementation 'junit:junit:4.12'
3235
androidTestImplementation 'androidx.test.ext:junit:1.1.1'
3336
androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0'
37+
38+
implementation(group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '1.0.0-beta6') {
39+
exclude group: 'org.bytedeco', module: 'opencv-platform'
40+
exclude group: 'org.bytedeco', module: 'leptonica-platform'
41+
exclude group: 'org.bytedeco', module: 'hdf5-platform'
42+
exclude group: 'org.nd4j', module: 'nd4j-base64'
43+
}
44+
45+
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta6'
46+
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta6', classifier: "android-arm"
47+
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta6', classifier: "android-arm64"
48+
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta6', classifier: "android-x86"
49+
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta6', classifier: "android-x86_64"
50+
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2'
51+
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm"
52+
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-arm64"
53+
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86"
54+
implementation group: 'org.bytedeco', name: 'openblas', version: '0.3.7-1.5.2', classifier: "android-x86_64"
55+
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2'
56+
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm"
57+
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-arm64"
58+
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86"
59+
implementation group: 'org.bytedeco', name: 'opencv', version: '4.1.2-1.5.2', classifier: "android-x86_64"
60+
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2'
61+
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm"
62+
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-arm64"
63+
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86"
64+
implementation group: 'org.bytedeco', name: 'leptonica', version: '1.78.0-1.5.2', classifier: "android-x86_64"
65+
66+
annotationProcessor group: 'org.projectlombok', name: 'lombok', version: '1.18.4'
67+
annotationProcessor group: 'org.projectlombok', name: 'lombok', version: '1.18.4'
3468
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.example.androidimageexperiment;
1+
package com.example.androidDl4jClassifier;
22

33
import android.content.Context;
44

@@ -22,6 +22,6 @@ public void useAppContext() {
2222
// Context of the app under test.
2323
Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
2424

25-
assertEquals("com.example.androidimageexperiment", appContext.getPackageName());
25+
assertEquals("com.example.androidDl4jClassifier", appContext.getPackageName());
2626
}
2727
}

android-examples/app/src/main/AndroidManifest.xml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
<?xml version="1.0" encoding="utf-8"?>
22
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
3-
package="com.example.androidimageexperiment">
3+
package="com.example.androidDl4jClassifier">
44

55
<application
66
android:allowBackup="true"
77
android:icon="@mipmap/ic_launcher"
88
android:label="@string/app_name"
99
android:roundIcon="@mipmap/ic_launcher_round"
1010
android:supportsRtl="true"
11-
android:theme="@style/AppTheme">
12-
<activity android:name=".MainActivity">
11+
android:theme="@style/AppTheme"
12+
android:fullBackupContent="@xml/backup_descriptor">
13+
<activity android:name="com.example.androidDl4jClassifier.MainActivity">
1314
<intent-filter>
1415
<action android:name="android.intent.action.MAIN" />
1516

android-examples/app/src/main/java/com/example/androidimageexperiment/MainActivity.java renamed to android-examples/app/src/main/java/com/example/androidDl4jClassifier/MainActivity.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
1-
package com.example.androidimageexperiment;
1+
package com.example.androidDl4jClassifier;
22

33
import androidx.appcompat.app.AppCompatActivity;
44

5-
import android.os.AsyncTask;
65
import android.os.Bundle;
7-
import android.util.Log;
8-
9-
import java.io.BufferedReader;
10-
import java.io.IOException;
11-
import java.io.InputStreamReader;
12-
import java.util.ArrayList;
13-
import java.util.Arrays;
146

157
public class MainActivity extends AppCompatActivity {
168

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
package com.example.androidDl4jClassifier;
2+
3+
import android.content.Context;
4+
import android.graphics.Canvas;
5+
import android.graphics.Color;
6+
import android.graphics.Paint;
7+
import android.os.AsyncTask;
8+
import android.util.AttributeSet;
9+
import android.view.View;
10+
11+
import androidx.annotation.Nullable;
12+
13+
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
14+
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
15+
import org.deeplearning4j.nn.conf.layers.DenseLayer;
16+
import org.deeplearning4j.nn.conf.layers.OutputLayer;
17+
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
18+
import org.deeplearning4j.nn.weights.WeightInit;
19+
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
20+
import org.nd4j.evaluation.classification.Evaluation;
21+
import org.nd4j.linalg.activations.Activation;
22+
import org.nd4j.linalg.api.buffer.DataType;
23+
import org.nd4j.linalg.api.ndarray.INDArray;
24+
import org.nd4j.linalg.dataset.DataSet;
25+
import org.nd4j.linalg.factory.Nd4j;
26+
import org.nd4j.linalg.learning.config.Nesterovs;
27+
import org.nd4j.linalg.lossfunctions.LossFunctions;
28+
29+
import java.io.BufferedReader;
30+
import java.io.IOException;
31+
import java.io.InputStreamReader;
32+
import java.util.ArrayList;
33+
34+
public class ScatterView extends View {
35+
36+
private final Paint redPaint;
37+
private final Paint greenPaint;
38+
private final Paint lightGreenPaint;
39+
private final Paint lightRedPaint;
40+
private float[][] data;
41+
private DataSet ds;
42+
43+
private final int nPointsPerAxis = 100;
44+
private INDArray xyGrid; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
45+
private MultiLayerNetwork model;
46+
47+
public ScatterView(Context context, @Nullable AttributeSet attrs) {
48+
super(context, attrs);
49+
data = null;
50+
redPaint = new Paint();
51+
redPaint.setColor(Color.RED);
52+
greenPaint = new Paint();
53+
greenPaint.setColor(Color.GREEN);
54+
55+
lightGreenPaint = new Paint();
56+
lightGreenPaint.setColor(Color.rgb(225, 255, 225));
57+
lightRedPaint = new Paint();
58+
lightRedPaint.setColor(Color.rgb(255, 153, 152));
59+
60+
AsyncTask.execute(() -> {
61+
try {
62+
calcGrid();
63+
ReadCSV();
64+
BuildNN();
65+
66+
} catch (IOException e) {
67+
e.printStackTrace();
68+
}
69+
});
70+
}
71+
72+
@Override
73+
public void onDraw(Canvas canvas) {
74+
int h = this.getHeight();
75+
int w = this.getWidth();
76+
77+
if (null == data) {
78+
canvas.drawColor(Color.rgb(32, 32, 32));
79+
canvas.drawCircle(800, 500, 200, redPaint);
80+
canvas.drawCircle(325, 900, 300, greenPaint);
81+
} else {
82+
83+
//draw the nn predictions:
84+
int halfRectHeight = h / nPointsPerAxis;
85+
int halfRectWidth = w / nPointsPerAxis;
86+
INDArray modelOut = model.output(xyGrid);
87+
88+
int nRows = xyGrid.rows();
89+
90+
for (int i = 0; i< nRows; i++){
91+
int x = (int)(xyGrid.getFloat(i, 0) * w);
92+
int y = (int) (xyGrid.getFloat(i, 1) * h);
93+
float z = modelOut.getFloat(i, 0);
94+
Paint p = (z >= 0.5f) ? lightGreenPaint : lightRedPaint;
95+
canvas.drawRect(x-halfRectWidth, y-halfRectHeight, x+halfRectWidth, y+halfRectHeight, p);
96+
// }
97+
}
98+
99+
100+
//draw the data set
101+
for (float[] datum : data) {
102+
int x = (int) (datum[1] * w);
103+
int y = (int) (datum[2] * h);
104+
Paint p = (datum[0] == 0.0f) ? redPaint : greenPaint;
105+
canvas.drawCircle(x, y, 10, p);
106+
}
107+
}
108+
}
109+
110+
/**
111+
* this is not the regular way to read a csv file into a data set with DL4j.
112+
* In this example we have put the data in the assets folder so that the demo works offline.
113+
*/
114+
private void ReadCSV() throws IOException {
115+
InputStreamReader is = new InputStreamReader(MainActivity.getInstance().getApplicationContext().getAssets()
116+
.open("linear_data_train.csv"));
117+
118+
BufferedReader reader = new BufferedReader(is);
119+
ArrayList<String> rawSVC = new ArrayList<>();
120+
String line;
121+
while ((line = reader.readLine()) != null) {
122+
rawSVC.add(line);
123+
}
124+
125+
float[][] tmpData = new float[rawSVC.size()][3];
126+
127+
int index = 0;
128+
for(String l : rawSVC){
129+
String[] values = l.split(",");
130+
for(int col = 0; col< 3L; col++){
131+
tmpData[index][col] = Float.parseFloat(values[col]);
132+
}
133+
134+
index++;
135+
}
136+
137+
normalizeColumn(1, tmpData);
138+
normalizeColumn(2, tmpData);
139+
140+
this.data = tmpData;
141+
INDArray arrData = Nd4j.createFromArray(tmpData);
142+
INDArray arrFeatures = arrData.getColumns(1, 2);
143+
INDArray c1 = arrData.getColumns(0);
144+
INDArray c2 = c1.mul(-1).addi(1.0);
145+
INDArray labels = Nd4j.hstack(c1, c2);
146+
ds = new DataSet(arrFeatures, labels);
147+
}
148+
149+
/**
150+
* Normalize the data in a given column. Normally one would use datavec.
151+
* @param c column to normalise.
152+
* @param tmpData java float array.
153+
*/
154+
private void normalizeColumn(int c, float[][] tmpData){
155+
int numPoints = tmpData.length;
156+
float min= tmpData[0][c];
157+
float max= tmpData[0][c];
158+
for (float[] tmpDatum : tmpData) {
159+
float x = tmpDatum[c];
160+
if (x < min) {
161+
min = x;
162+
}
163+
if (x > max) {
164+
max = x;
165+
}
166+
}
167+
168+
for (int i=0; i<numPoints; i++){
169+
float x = tmpData[i][c];
170+
tmpData[i][c] = (x - min) / (max - min);
171+
}
172+
}
173+
174+
private void BuildNN(){
175+
int seed = 123;
176+
double learningRate = 0.01;
177+
int numInputs = 2;
178+
int numOutputs = 2;
179+
int numHiddenNodes = 20;
180+
int nEpochs = 200;
181+
182+
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
183+
.seed(seed)
184+
.weightInit(WeightInit.XAVIER)
185+
.updater(new Nesterovs(learningRate, 0.9))
186+
.list()
187+
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
188+
.activation(Activation.RELU)
189+
.build())
190+
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
191+
.activation(Activation.SOFTMAX)
192+
.nIn(numHiddenNodes).nOut(numOutputs).build())
193+
.build();
194+
195+
model = new MultiLayerNetwork(conf);
196+
model.init();
197+
model.setListeners(new ScoreIterationListener(10));
198+
199+
for(int i = 0; i<nEpochs; i++){
200+
model.fit(ds);
201+
}
202+
203+
Evaluation eval = new Evaluation(numOutputs);
204+
INDArray features = ds.getFeatures();
205+
INDArray labels = ds.getLabels();
206+
INDArray predicted = model.output(features,false);
207+
eval.eval(labels, predicted);
208+
System.out.println(eval.stats());
209+
210+
this.invalidate();
211+
}
212+
/**
213+
* The x,y grid to calculate the NN output. Only needs to be calculated once.
214+
*/
215+
private void calcGrid(){
216+
// x coordinates of the pixels for the NN.
217+
INDArray xPixels = Nd4j.linspace(0, 1.0, nPointsPerAxis, DataType.DOUBLE);
218+
// y coordinates of the pixels for the NN.
219+
INDArray yPixels = Nd4j.linspace(0, 1.0, nPointsPerAxis, DataType.DOUBLE);
220+
//create the mesh:
221+
INDArray [] mesh = Nd4j.meshgrid(xPixels, yPixels);
222+
xyGrid = Nd4j.vstack(mesh[0].ravel(), mesh[1].ravel()).transpose();
223+
}
224+
}

0 commit comments

Comments
 (0)