Skip to content

Commit c0ad5e8

Browse files
committed
code review items.
Signed-off-by: Robert Altena <[email protected]>
1 parent f503fed commit c0ad5e8

File tree

1 file changed

+88
-63
lines changed
  • dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples

1 file changed

+88
-63
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples/ImageDrawer.java

Lines changed: 88 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,21 @@
3434
import org.nd4j.linalg.indexing.conditions.Conditions;
3535
import org.nd4j.linalg.learning.config.Adam;
3636
import org.nd4j.linalg.lossfunctions.LossFunctions;
37+
import org.nd4j.shade.guava.collect.Streams;
3738

3839
import javax.imageio.ImageIO;
3940
import javax.swing.*;
4041
import java.awt.image.BufferedImage;
42+
import java.awt.image.DataBufferByte;
4143
import java.io.File;
44+
import java.util.Random;
4245

4346
/**
4447
* Application to show a neural network learning to draw an image.
4548
* Demonstrates how to feed an NN with externally originated data.
4649
*
4750
* Updates from previous versions:
4851
* - Now uses swing. No longer uses JavaFX which caused problems with the OpenJDK.
49-
* - All slow java loops in the dataset creation and image drawing are replaced with fast vectorized code.
5052
*
5153
* @author Robert Altena
5254
* Many thanks to @tmanthey for constructive feedback and suggestions.
@@ -59,17 +61,11 @@ public class ImageDrawer {
5961
private BufferedImage originalImage;
6062
private JLabel generatedLabel;
6163

62-
private INDArray blueMat; // color channels of he original image.
63-
private INDArray greenMat;
64-
private INDArray redMat;
65-
66-
private INDArray xPixels; // x coordinates of the pixels for the NN.
67-
private INDArray yPixels; // y coordinates of the pixels for the NN.
68-
6964
private INDArray xyOut; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
7065

7166
private Java2DNativeImageLoader j2dNil; //Datavec class used to read and write images to /from INDArrays.
72-
67+
private FastRGB rgb; // helper class for fast access to the image pixels.
68+
private Random random;
7369

7470
private void init() throws Exception {
7571

@@ -78,6 +74,7 @@ private void init() throws Exception {
7874

7975
String localDataPath = DownloaderUtility.DATAEXAMPLES.Download();
8076
originalImage = ImageIO.read(new File(localDataPath, "Mona_Lisa.png"));
77+
8178
//start with a blank image of the same size as the original.
8279
BufferedImage generatedImage = new BufferedImage(originalImage.getWidth(), originalImage.getHeight(), originalImage.getType());
8380

@@ -98,15 +95,13 @@ private void init() throws Exception {
9895
mainFrame.setVisible(true); // Show UI
9996

10097

101-
j2dNil = new Java2DNativeImageLoader(); //Datavec class used to read and write images.
98+
j2dNil = new Java2DNativeImageLoader(); //Datavec class used to write images.
99+
random = new Random();
102100
nn = createNN(); // Create the neural network.
103101
xyOut = calcGrid(); //Create a mesh used to generate the image.
104102

105103
// read the color channels from the original image.
106-
INDArray imageMat = j2dNil.asMatrix(originalImage).castTo(DataType.DOUBLE).div(255.0);
107-
blueMat = imageMat .tensorAlongDimension(1, 0, 2, 3).reshape(width * height, 1);
108-
greenMat = imageMat .tensorAlongDimension(2, 0, 2, 3).reshape(width * height, 1);
109-
redMat = imageMat .tensorAlongDimension(3, 0, 2, 3).reshape(width * height, 1);
104+
rgb = new FastRGB(originalImage);
110105

111106
SwingUtilities.invokeLater(this::onCalc);
112107
}
@@ -127,30 +122,30 @@ private static MultiLayerNetwork createNN() {
127122
int numOutputs = 3 ; //R, G and B value.
128123

129124
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
130-
.seed(seed)
131-
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
132-
.weightInit(WeightInit.XAVIER)
133-
.updater(new Adam(learningRate))
134-
.list()
135-
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
136-
.activation(Activation.LEAKYRELU)
137-
.build())
138-
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
139-
.activation(Activation.LEAKYRELU)
140-
.build())
141-
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
142-
.activation(Activation.LEAKYRELU)
143-
.build())
144-
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
145-
.activation(Activation.LEAKYRELU)
146-
.build())
147-
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
148-
.activation(Activation.LEAKYRELU)
149-
.build())
150-
.layer( new OutputLayer.Builder(LossFunctions.LossFunction.L2)
151-
.activation(Activation.IDENTITY)
152-
.nOut(numOutputs).build())
153-
.build();
125+
.seed(seed)
126+
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
127+
.weightInit(WeightInit.XAVIER)
128+
.updater(new Adam(learningRate))
129+
.list()
130+
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
131+
.activation(Activation.LEAKYRELU)
132+
.build())
133+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
134+
.activation(Activation.LEAKYRELU)
135+
.build())
136+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
137+
.activation(Activation.LEAKYRELU)
138+
.build())
139+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
140+
.activation(Activation.LEAKYRELU)
141+
.build())
142+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
143+
.activation(Activation.LEAKYRELU)
144+
.build())
145+
.layer( new OutputLayer.Builder(LossFunctions.LossFunction.L2)
146+
.activation(Activation.IDENTITY)
147+
.nOut(numOutputs).build())
148+
.build();
154149

155150
MultiLayerNetwork net = new MultiLayerNetwork(conf);
156151
net.init();
@@ -162,8 +157,9 @@ private static MultiLayerNetwork createNN() {
162157
* Training the NN and updating the current graphical output.
163158
*/
164159
private void onCalc(){
165-
int batchSize = 1000;
166-
int numBatches = 10;
160+
// Find a reasonable balance between batch size and number of batches per generated redraw.
161+
int batchSize = 1000; //larger batch size slows the calculation but speeds up the learning per batch
162+
int numBatches = 10; // Drawing the generated image is slow. Doing multiple batches before redrawing increases speed.
167163
for (int i =0; i< numBatches; i++){
168164
DataSet ds = generateDataSet(batchSize);
169165
nn.fit(ds);
@@ -172,11 +168,13 @@ private void onCalc(){
172168
mainFrame.invalidate();
173169
mainFrame.repaint();
174170

175-
SwingUtilities.invokeLater(this::onCalc);
171+
SwingUtilities.invokeLater(this::onCalc); //TODO: move training to a worker thread,
176172
}
177173

178174
/**
179175
* Take a batchsize of random samples from the source image.
176+
* This illustrates how to generate a custom dataset. The normal way of doing this would be to generate a dataset
177+
* of the entire source image, then train on shuffled batches from there.
180178
*
181179
* @param batchSize number of sample points to take out of the image.
182180
* @return DeepLearning4J DataSet.
@@ -185,22 +183,22 @@ private DataSet generateDataSet(int batchSize) {
185183
int w = originalImage.getWidth();
186184
int h = originalImage.getHeight();
187185

188-
INDArray xindex = Nd4j.rand(batchSize).muli(w-1).castTo(DataType.UINT32);
189-
INDArray yindex = Nd4j.rand(batchSize).muli(h-1).castTo(DataType.UINT32);
190-
191-
INDArray xPos = xPixels.get(xindex).reshape(batchSize); // Look up the normalized positions pf the pixels.
192-
INDArray yPos = yPixels.get(yindex).reshape(batchSize);
193-
194-
INDArray xy = Nd4j.vstack(xPos, yPos).transpose(); // Create the array that can be fed into the NN.
195-
196-
//Look up the correct colors fot our random pixels.
197-
INDArray xyIndex = yindex.mul(w).add(xindex); //TODO: figure out the 2D version of INDArray.get.
198-
INDArray b = blueMat.get(xyIndex).reshape(batchSize);
199-
INDArray g = greenMat.get(xyIndex).reshape(batchSize);
200-
INDArray r = redMat.get(xyIndex).reshape(batchSize);
201-
INDArray out = Nd4j.vstack(r, g, b).transpose(); // Create the array that can be used for NN training.
202-
203-
return new DataSet(xy, out);
186+
float[][] in = new float[batchSize][2];
187+
float[][] out = new float[batchSize][3];
188+
final int[] i = {0};
189+
Streams.forEachPair(
190+
random.ints(batchSize, 0, w).boxed(),
191+
random.ints(batchSize, 0, h).boxed(),
192+
(a, b) -> {
193+
final short[] parts = rgb.getRGB(a, b);
194+
in[i[0]] = new float[]{((a / (float)w) - 0.5f) * 2f, ((b / (float)h) - 0.5f) * 2f};
195+
out[i[0]] = new float[]{parts[0], parts[1], parts[2]};
196+
i[0]++;
197+
}
198+
);
199+
final INDArray input = Nd4j.create(in);
200+
final INDArray labels = Nd4j.create(out).divi(255);
201+
return new DataSet(input, labels);
204202
}
205203

206204
/**
@@ -211,7 +209,7 @@ private void drawImage() {
211209
int h = originalImage.getHeight();
212210

213211
INDArray out = nn.output(xyOut); // The raw NN output.
214-
BooleanIndexing.replaceWhere(out, 0.0, Conditions.lessThan(0.0)); // Cjip between 0 and 1.
212+
BooleanIndexing.replaceWhere(out, 0.0, Conditions.lessThan(0.0)); // Clip between 0 and 1.
215213
BooleanIndexing.replaceWhere(out, 1.0, Conditions.greaterThan(1.0));
216214
out = out.mul(255).castTo(DataType.BYTE); //convert to bytes.
217215

@@ -231,13 +229,40 @@ private void drawImage() {
231229
private INDArray calcGrid(){
232230
int w = originalImage.getWidth();
233231
int h = originalImage.getHeight();
234-
xPixels = Nd4j.linspace(-1.0, 1.0, w, DataType.DOUBLE);
235-
yPixels = Nd4j.linspace(-1.0, 1.0, h, DataType.DOUBLE);
232+
INDArray xPixels = Nd4j.linspace(-1.0, 1.0, w, DataType.DOUBLE);
233+
INDArray yPixels = Nd4j.linspace(-1.0, 1.0, h, DataType.DOUBLE);
236234
INDArray [] mesh = Nd4j.meshgrid(xPixels, yPixels);
237235

238-
xPixels = xPixels.reshape(w, 1); // This is a hack to work around a bug in INDArray.get()
239-
yPixels = yPixels.reshape(h, 1); // in the dataset generation.
240-
241236
return Nd4j.vstack(mesh[0].ravel(), mesh[1].ravel()).transpose();
242237
}
238+
239+
240+
/**
 * Helper for fast access to the colour components of individual pixels of a {@link BufferedImage}.
 *
 * <p>When the image is stored as tightly packed, byte-interleaved BGR or ABGR data
 * ({@code TYPE_3BYTE_BGR} / {@code TYPE_4BYTE_ABGR}) the components are read directly from the
 * raster's backing byte array. For every other image layout (int-packed, gray, indexed,
 * premultiplied, custom, ...) the class falls back to {@link BufferedImage#getRGB(int, int)},
 * which is slower but converts any layout correctly instead of failing or returning wrong colours.
 */
public class FastRGB {
    int width;
    int height;
    private final boolean hasAlphaChannel;
    private final int pixelLength;   // bytes per pixel on the fast path: 3 (BGR) or 4 (ABGR)
    private final byte[] pixels;     // backing byte array for the fast path, null when falling back
    private final BufferedImage image; // only kept when the slow fallback path is needed, else null

    /**
     * @param image the image whose pixels will be read; must not be {@code null}.
     */
    FastRGB(BufferedImage image) {
        width = image.getWidth();
        height = image.getHeight();
        hasAlphaChannel = image.getAlphaRaster() != null;
        pixelLength = hasAlphaChannel ? 4 : 3;

        // The raw-byte layout is only known for these two image types.
        int type = image.getType();
        boolean knownLayout = type == BufferedImage.TYPE_3BYTE_BGR || type == BufferedImage.TYPE_4BYTE_ABGR;

        byte[] raw = null;
        if (knownLayout && image.getRaster().getDataBuffer() instanceof DataBufferByte) {
            byte[] candidate = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
            // Guard against sub-images / padded scanlines, where the index arithmetic below would be wrong.
            if (candidate.length == width * height * pixelLength) {
                raw = candidate;
            }
        }
        this.pixels = raw;
        this.image = (raw == null) ? image : null;
    }

    /**
     * Reads the colour of one pixel.
     *
     * @param x column of the pixel, {@code 0 <= x < width}.
     * @param y row of the pixel, {@code 0 <= y < height}.
     * @return a new array {@code {red, green, blue, alpha}} with every component in the range 0..255.
     *         The alpha entry is 0 when the image has no alpha channel.
     * @throws IndexOutOfBoundsException if the coordinates lie outside the image.
     */
    short[] getRGB(int x, int y) {
        // Without this check an x beyond the row end would silently read a pixel from the next row.
        if (x < 0 || x >= width || y < 0 || y >= height) {
            throw new IndexOutOfBoundsException(
                "Pixel (" + x + ", " + y + ") is outside the " + width + "x" + height + " image");
        }

        short[] rgb = new short[4];
        if (pixels == null) {
            // Slow path: let the image convert its own layout to packed ARGB.
            int argb = image.getRGB(x, y);
            if (hasAlphaChannel) {
                rgb[3] = (short) ((argb >>> 24) & 0xFF); // Alpha
            }
            rgb[0] = (short) ((argb >>> 16) & 0xFF);     // Red
            rgb[1] = (short) ((argb >>> 8) & 0xFF);      // Green
            rgb[2] = (short) (argb & 0xFF);              // Blue
            return rgb;
        }

        // Fast path: bytes are stored as [A,] B, G, R per pixel, rows packed back to back.
        int pos = (y * width + x) * pixelLength;
        if (hasAlphaChannel) {
            rgb[3] = (short) (pixels[pos++] & 0xFF); // Alpha
        }
        rgb[2] = (short) (pixels[pos++] & 0xFF); // Blue
        rgb[1] = (short) (pixels[pos++] & 0xFF); // Green
        rgb[0] = (short) (pixels[pos] & 0xFF);   // Red
        return rgb;
    }
}
243268
}

0 commit comments

Comments
 (0)