Skip to content

Commit 5e1e912

Browse files
authored
Merge pull request #959 from RobAltena/master
Removes JavaFX from ImageDrawer. Remove java loops.
2 parents 97888ba + c0ad5e8 commit 5e1e912

File tree

2 files changed

+268
-263
lines changed
  • dl4j-examples_javafx/src/main/java/org/deeplearning4j/examples/dataexamples
  • dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples

2 files changed

+268
-263
lines changed
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
/* *****************************************************************************
2+
* Copyright (c) 2020 Konduit, Inc.
3+
*
4+
* This program and the accompanying materials are made available under the
5+
* terms of the Apache License, Version 2.0 which is available at
6+
* https://www.apache.org/licenses/LICENSE-2.0.
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11+
* License for the specific language governing permissions and limitations
12+
* under the License.
13+
*
14+
* SPDX-License-Identifier: Apache-2.0
15+
******************************************************************************/
16+
17+
package org.deeplearning4j.examples.dataexamples;
18+
19+
import org.datavec.image.loader.Java2DNativeImageLoader;
20+
import org.deeplearning4j.examples.download.DownloaderUtility;
21+
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
22+
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
23+
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
24+
import org.deeplearning4j.nn.conf.layers.DenseLayer;
25+
import org.deeplearning4j.nn.conf.layers.OutputLayer;
26+
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
27+
import org.deeplearning4j.nn.weights.WeightInit;
28+
import org.nd4j.linalg.activations.Activation;
29+
import org.nd4j.linalg.api.buffer.DataType;
30+
import org.nd4j.linalg.api.ndarray.INDArray;
31+
import org.nd4j.linalg.dataset.DataSet;
32+
import org.nd4j.linalg.factory.Nd4j;
33+
import org.nd4j.linalg.indexing.BooleanIndexing;
34+
import org.nd4j.linalg.indexing.conditions.Conditions;
35+
import org.nd4j.linalg.learning.config.Adam;
36+
import org.nd4j.linalg.lossfunctions.LossFunctions;
37+
import org.nd4j.shade.guava.collect.Streams;
38+
39+
import javax.imageio.ImageIO;
40+
import javax.swing.*;
41+
import java.awt.image.BufferedImage;
42+
import java.awt.image.DataBufferByte;
43+
import java.io.File;
44+
import java.util.Random;
45+
46+
/**
47+
* Application to show a neural network learning to draw an image.
48+
* Demonstrates how to feed an NN with externally originated data.
49+
*
50+
* Updates from previous versions:
51+
* - Now uses swing. No longer uses JavaFX which caused problems with the OpenJDK.
52+
*
53+
* @author Robert Altena
54+
* Many thanks to @tmanthey for constructive feedback and suggestions.
55+
*/
56+
public class ImageDrawer {
57+
58+
private JFrame mainFrame;
59+
private MultiLayerNetwork nn; // The neural network.
60+
61+
private BufferedImage originalImage;
62+
private JLabel generatedLabel;
63+
64+
private INDArray xyOut; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
65+
66+
private Java2DNativeImageLoader j2dNil; //Datavec class used to read and write images to /from INDArrays.
67+
private FastRGB rgb; // helper class for fast access to the image pixels.
68+
private Random random;
69+
70+
private void init() throws Exception {
71+
72+
mainFrame = new JFrame("Image drawer example");//creating instance of JFrame
73+
mainFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
74+
75+
String localDataPath = DownloaderUtility.DATAEXAMPLES.Download();
76+
originalImage = ImageIO.read(new File(localDataPath, "Mona_Lisa.png"));
77+
78+
//start with a blank image of the same size as the original.
79+
BufferedImage generatedImage = new BufferedImage(originalImage.getWidth(), originalImage.getHeight(), originalImage.getType());
80+
81+
int width = originalImage.getWidth();
82+
int height = originalImage.getHeight();
83+
84+
final JLabel originalLabel = new JLabel(new ImageIcon(originalImage));
85+
generatedLabel = new JLabel(new ImageIcon(generatedImage));
86+
87+
originalLabel.setBounds(0,0, width, height);
88+
generatedLabel.setBounds(width, 0, width, height);//x axis, y axis, width, height
89+
90+
mainFrame.add(originalLabel);
91+
mainFrame.add(generatedLabel);
92+
93+
mainFrame.setSize(2*width, height +25);
94+
mainFrame.setLayout(null);
95+
mainFrame.setVisible(true); // Show UI
96+
97+
98+
j2dNil = new Java2DNativeImageLoader(); //Datavec class used to write images.
99+
random = new Random();
100+
nn = createNN(); // Create the neural network.
101+
xyOut = calcGrid(); //Create a mesh used to generate the image.
102+
103+
// read the color channels from the original image.
104+
rgb = new FastRGB(originalImage);
105+
106+
SwingUtilities.invokeLater(this::onCalc);
107+
}
108+
109+
public static void main(String[] args) throws Exception {
110+
ImageDrawer imageDrawer = new ImageDrawer();
111+
imageDrawer.init();
112+
}
113+
114+
/**
115+
* Build the Neural network.
116+
*/
117+
private static MultiLayerNetwork createNN() {
118+
int seed = 2345;
119+
double learningRate = 0.001;
120+
int numInputs = 2; // x and y.
121+
int numHiddenNodes = 1000;
122+
int numOutputs = 3 ; //R, G and B value.
123+
124+
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
125+
.seed(seed)
126+
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
127+
.weightInit(WeightInit.XAVIER)
128+
.updater(new Adam(learningRate))
129+
.list()
130+
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
131+
.activation(Activation.LEAKYRELU)
132+
.build())
133+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
134+
.activation(Activation.LEAKYRELU)
135+
.build())
136+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
137+
.activation(Activation.LEAKYRELU)
138+
.build())
139+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
140+
.activation(Activation.LEAKYRELU)
141+
.build())
142+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
143+
.activation(Activation.LEAKYRELU)
144+
.build())
145+
.layer( new OutputLayer.Builder(LossFunctions.LossFunction.L2)
146+
.activation(Activation.IDENTITY)
147+
.nOut(numOutputs).build())
148+
.build();
149+
150+
MultiLayerNetwork net = new MultiLayerNetwork(conf);
151+
net.init();
152+
153+
return net;
154+
}
155+
156+
/**
157+
* Training the NN and updating the current graphical output.
158+
*/
159+
private void onCalc(){
160+
// Find a reasonable balance between batch size and number of batches per generated redraw.
161+
int batchSize = 1000; //larger batch size slows the calculation but speeds up the learning per batch
162+
int numBatches = 10; // Drawing the generated image is slow. Doing multiple batches before redrawing increases speed.
163+
for (int i =0; i< numBatches; i++){
164+
DataSet ds = generateDataSet(batchSize);
165+
nn.fit(ds);
166+
}
167+
drawImage();
168+
mainFrame.invalidate();
169+
mainFrame.repaint();
170+
171+
SwingUtilities.invokeLater(this::onCalc); //TODO: move training to a worker thread,
172+
}
173+
174+
/**
175+
* Take a batchsize of random samples from the source image.
176+
* This illustrates how to generate a custom dataset. The normal way of doing this would be to generate a dataset
177+
* of the entire source image, then train on shuffled batches from there.
178+
*
179+
* @param batchSize number of sample points to take out of the image.
180+
* @return DeepLearning4J DataSet.
181+
*/
182+
private DataSet generateDataSet(int batchSize) {
183+
int w = originalImage.getWidth();
184+
int h = originalImage.getHeight();
185+
186+
float[][] in = new float[batchSize][2];
187+
float[][] out = new float[batchSize][3];
188+
final int[] i = {0};
189+
Streams.forEachPair(
190+
random.ints(batchSize, 0, w).boxed(),
191+
random.ints(batchSize, 0, h).boxed(),
192+
(a, b) -> {
193+
final short[] parts = rgb.getRGB(a, b);
194+
in[i[0]] = new float[]{((a / (float)w) - 0.5f) * 2f, ((b / (float)h) - 0.5f) * 2f};
195+
out[i[0]] = new float[]{parts[0], parts[1], parts[2]};
196+
i[0]++;
197+
}
198+
);
199+
final INDArray input = Nd4j.create(in);
200+
final INDArray labels = Nd4j.create(out).divi(255);
201+
return new DataSet(input, labels);
202+
}
203+
204+
/**
205+
* Make the Neural network draw the image.
206+
*/
207+
private void drawImage() {
208+
int w = originalImage.getWidth();
209+
int h = originalImage.getHeight();
210+
211+
INDArray out = nn.output(xyOut); // The raw NN output.
212+
BooleanIndexing.replaceWhere(out, 0.0, Conditions.lessThan(0.0)); // Clip between 0 and 1.
213+
BooleanIndexing.replaceWhere(out, 1.0, Conditions.greaterThan(1.0));
214+
out = out.mul(255).castTo(DataType.BYTE); //convert to bytes.
215+
216+
INDArray r = out.getColumn(0); //Extract the individual color layers.
217+
INDArray g = out.getColumn(1);
218+
INDArray b = out.getColumn(2);
219+
220+
INDArray imgArr = Nd4j.vstack(b, g, r).reshape(3, h, w); // recombine the colors and reshape to image size.
221+
222+
BufferedImage img = j2dNil.asBufferedImage(imgArr); //update the UI.
223+
generatedLabel.setIcon(new ImageIcon(img));
224+
}
225+
226+
/**
227+
* The x,y grid to calculate the NN output. Only needs to be calculated once.
228+
*/
229+
private INDArray calcGrid(){
230+
int w = originalImage.getWidth();
231+
int h = originalImage.getHeight();
232+
INDArray xPixels = Nd4j.linspace(-1.0, 1.0, w, DataType.DOUBLE);
233+
INDArray yPixels = Nd4j.linspace(-1.0, 1.0, h, DataType.DOUBLE);
234+
INDArray [] mesh = Nd4j.meshgrid(xPixels, yPixels);
235+
236+
return Nd4j.vstack(mesh[0].ravel(), mesh[1].ravel()).transpose();
237+
}
238+
239+
240+
public class FastRGB {
241+
int width;
242+
int height;
243+
private boolean hasAlphaChannel;
244+
private int pixelLength;
245+
private byte[] pixels;
246+
247+
FastRGB(BufferedImage image) {
248+
pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
249+
width = image.getWidth();
250+
height = image.getHeight();
251+
hasAlphaChannel = image.getAlphaRaster() != null;
252+
pixelLength = 3;
253+
if (hasAlphaChannel)
254+
pixelLength = 4;
255+
}
256+
257+
short[] getRGB(int x, int y) {
258+
int pos = (y * pixelLength * width) + (x * pixelLength);
259+
short rgb[] = new short[4];
260+
if (hasAlphaChannel)
261+
rgb[3] = (short) (pixels[pos++] & 0xFF); // Alpha
262+
rgb[2] = (short) (pixels[pos++] & 0xFF); // Blue
263+
rgb[1] = (short) (pixels[pos++] & 0xFF); // Green
264+
rgb[0] = (short) (pixels[pos] & 0xFF); // Red
265+
return rgb;
266+
}
267+
}
268+
}

0 commit comments

Comments
 (0)