Skip to content

Commit f503fed

Browse files
committed
Removes JavaFX from ImageDrawer. Remove java loops.
Signed-off-by: Robert Altena <[email protected]>
1 parent 97888ba commit f503fed

File tree

2 files changed

+243
-263
lines changed
  • dl4j-examples_javafx/src/main/java/org/deeplearning4j/examples/dataexamples
  • dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples

2 files changed

+243
-263
lines changed
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/* *****************************************************************************
2+
* Copyright (c) 2020 Konduit, Inc.
3+
*
4+
* This program and the accompanying materials are made available under the
5+
* terms of the Apache License, Version 2.0 which is available at
6+
* https://www.apache.org/licenses/LICENSE-2.0.
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11+
* License for the specific language governing permissions and limitations
12+
* under the License.
13+
*
14+
* SPDX-License-Identifier: Apache-2.0
15+
******************************************************************************/
16+
17+
package org.deeplearning4j.examples.dataexamples;
18+
19+
import org.datavec.image.loader.Java2DNativeImageLoader;
20+
import org.deeplearning4j.examples.download.DownloaderUtility;
21+
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
22+
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
23+
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
24+
import org.deeplearning4j.nn.conf.layers.DenseLayer;
25+
import org.deeplearning4j.nn.conf.layers.OutputLayer;
26+
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
27+
import org.deeplearning4j.nn.weights.WeightInit;
28+
import org.nd4j.linalg.activations.Activation;
29+
import org.nd4j.linalg.api.buffer.DataType;
30+
import org.nd4j.linalg.api.ndarray.INDArray;
31+
import org.nd4j.linalg.dataset.DataSet;
32+
import org.nd4j.linalg.factory.Nd4j;
33+
import org.nd4j.linalg.indexing.BooleanIndexing;
34+
import org.nd4j.linalg.indexing.conditions.Conditions;
35+
import org.nd4j.linalg.learning.config.Adam;
36+
import org.nd4j.linalg.lossfunctions.LossFunctions;
37+
38+
import javax.imageio.ImageIO;
39+
import javax.swing.*;
40+
import java.awt.image.BufferedImage;
41+
import java.io.File;
42+
43+
/**
44+
* Application to show a neural network learning to draw an image.
45+
* Demonstrates how to feed an NN with externally originated data.
46+
*
47+
* Updates from previous versions:
48+
* - Now uses swing. No longer uses JavaFX which caused problems with the OpenJDK.
49+
* - All slow java loops in the dataset creation and image drawing are replaced with fast vectorized code.
50+
*
51+
* @author Robert Altena
52+
* Many thanks to @tmanthey for constructive feedback and suggestions.
53+
*/
54+
public class ImageDrawer {
55+
56+
private JFrame mainFrame;
57+
private MultiLayerNetwork nn; // The neural network.
58+
59+
private BufferedImage originalImage;
60+
private JLabel generatedLabel;
61+
62+
private INDArray blueMat; // color channels of he original image.
63+
private INDArray greenMat;
64+
private INDArray redMat;
65+
66+
private INDArray xPixels; // x coordinates of the pixels for the NN.
67+
private INDArray yPixels; // y coordinates of the pixels for the NN.
68+
69+
private INDArray xyOut; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
70+
71+
private Java2DNativeImageLoader j2dNil; //Datavec class used to read and write images to /from INDArrays.
72+
73+
74+
private void init() throws Exception {
75+
76+
mainFrame = new JFrame("Image drawer example");//creating instance of JFrame
77+
mainFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
78+
79+
String localDataPath = DownloaderUtility.DATAEXAMPLES.Download();
80+
originalImage = ImageIO.read(new File(localDataPath, "Mona_Lisa.png"));
81+
//start with a blank image of the same size as the original.
82+
BufferedImage generatedImage = new BufferedImage(originalImage.getWidth(), originalImage.getHeight(), originalImage.getType());
83+
84+
int width = originalImage.getWidth();
85+
int height = originalImage.getHeight();
86+
87+
final JLabel originalLabel = new JLabel(new ImageIcon(originalImage));
88+
generatedLabel = new JLabel(new ImageIcon(generatedImage));
89+
90+
originalLabel.setBounds(0,0, width, height);
91+
generatedLabel.setBounds(width, 0, width, height);//x axis, y axis, width, height
92+
93+
mainFrame.add(originalLabel);
94+
mainFrame.add(generatedLabel);
95+
96+
mainFrame.setSize(2*width, height +25);
97+
mainFrame.setLayout(null);
98+
mainFrame.setVisible(true); // Show UI
99+
100+
101+
j2dNil = new Java2DNativeImageLoader(); //Datavec class used to read and write images.
102+
nn = createNN(); // Create the neural network.
103+
xyOut = calcGrid(); //Create a mesh used to generate the image.
104+
105+
// read the color channels from the original image.
106+
INDArray imageMat = j2dNil.asMatrix(originalImage).castTo(DataType.DOUBLE).div(255.0);
107+
blueMat = imageMat .tensorAlongDimension(1, 0, 2, 3).reshape(width * height, 1);
108+
greenMat = imageMat .tensorAlongDimension(2, 0, 2, 3).reshape(width * height, 1);
109+
redMat = imageMat .tensorAlongDimension(3, 0, 2, 3).reshape(width * height, 1);
110+
111+
SwingUtilities.invokeLater(this::onCalc);
112+
}
113+
114+
public static void main(String[] args) throws Exception {
115+
ImageDrawer imageDrawer = new ImageDrawer();
116+
imageDrawer.init();
117+
}
118+
119+
/**
120+
* Build the Neural network.
121+
*/
122+
private static MultiLayerNetwork createNN() {
123+
int seed = 2345;
124+
double learningRate = 0.001;
125+
int numInputs = 2; // x and y.
126+
int numHiddenNodes = 1000;
127+
int numOutputs = 3 ; //R, G and B value.
128+
129+
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
130+
.seed(seed)
131+
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
132+
.weightInit(WeightInit.XAVIER)
133+
.updater(new Adam(learningRate))
134+
.list()
135+
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
136+
.activation(Activation.LEAKYRELU)
137+
.build())
138+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
139+
.activation(Activation.LEAKYRELU)
140+
.build())
141+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
142+
.activation(Activation.LEAKYRELU)
143+
.build())
144+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
145+
.activation(Activation.LEAKYRELU)
146+
.build())
147+
.layer(new DenseLayer.Builder().nOut(numHiddenNodes )
148+
.activation(Activation.LEAKYRELU)
149+
.build())
150+
.layer( new OutputLayer.Builder(LossFunctions.LossFunction.L2)
151+
.activation(Activation.IDENTITY)
152+
.nOut(numOutputs).build())
153+
.build();
154+
155+
MultiLayerNetwork net = new MultiLayerNetwork(conf);
156+
net.init();
157+
158+
return net;
159+
}
160+
161+
/**
162+
* Training the NN and updating the current graphical output.
163+
*/
164+
private void onCalc(){
165+
int batchSize = 1000;
166+
int numBatches = 10;
167+
for (int i =0; i< numBatches; i++){
168+
DataSet ds = generateDataSet(batchSize);
169+
nn.fit(ds);
170+
}
171+
drawImage();
172+
mainFrame.invalidate();
173+
mainFrame.repaint();
174+
175+
SwingUtilities.invokeLater(this::onCalc);
176+
}
177+
178+
/**
179+
* Take a batchsize of random samples from the source image.
180+
*
181+
* @param batchSize number of sample points to take out of the image.
182+
* @return DeepLearning4J DataSet.
183+
*/
184+
private DataSet generateDataSet(int batchSize) {
185+
int w = originalImage.getWidth();
186+
int h = originalImage.getHeight();
187+
188+
INDArray xindex = Nd4j.rand(batchSize).muli(w-1).castTo(DataType.UINT32);
189+
INDArray yindex = Nd4j.rand(batchSize).muli(h-1).castTo(DataType.UINT32);
190+
191+
INDArray xPos = xPixels.get(xindex).reshape(batchSize); // Look up the normalized positions pf the pixels.
192+
INDArray yPos = yPixels.get(yindex).reshape(batchSize);
193+
194+
INDArray xy = Nd4j.vstack(xPos, yPos).transpose(); // Create the array that can be fed into the NN.
195+
196+
//Look up the correct colors fot our random pixels.
197+
INDArray xyIndex = yindex.mul(w).add(xindex); //TODO: figure out the 2D version of INDArray.get.
198+
INDArray b = blueMat.get(xyIndex).reshape(batchSize);
199+
INDArray g = greenMat.get(xyIndex).reshape(batchSize);
200+
INDArray r = redMat.get(xyIndex).reshape(batchSize);
201+
INDArray out = Nd4j.vstack(r, g, b).transpose(); // Create the array that can be used for NN training.
202+
203+
return new DataSet(xy, out);
204+
}
205+
206+
/**
207+
* Make the Neural network draw the image.
208+
*/
209+
private void drawImage() {
210+
int w = originalImage.getWidth();
211+
int h = originalImage.getHeight();
212+
213+
INDArray out = nn.output(xyOut); // The raw NN output.
214+
BooleanIndexing.replaceWhere(out, 0.0, Conditions.lessThan(0.0)); // Cjip between 0 and 1.
215+
BooleanIndexing.replaceWhere(out, 1.0, Conditions.greaterThan(1.0));
216+
out = out.mul(255).castTo(DataType.BYTE); //convert to bytes.
217+
218+
INDArray r = out.getColumn(0); //Extract the individual color layers.
219+
INDArray g = out.getColumn(1);
220+
INDArray b = out.getColumn(2);
221+
222+
INDArray imgArr = Nd4j.vstack(b, g, r).reshape(3, h, w); // recombine the colors and reshape to image size.
223+
224+
BufferedImage img = j2dNil.asBufferedImage(imgArr); //update the UI.
225+
generatedLabel.setIcon(new ImageIcon(img));
226+
}
227+
228+
/**
229+
* The x,y grid to calculate the NN output. Only needs to be calculated once.
230+
*/
231+
private INDArray calcGrid(){
232+
int w = originalImage.getWidth();
233+
int h = originalImage.getHeight();
234+
xPixels = Nd4j.linspace(-1.0, 1.0, w, DataType.DOUBLE);
235+
yPixels = Nd4j.linspace(-1.0, 1.0, h, DataType.DOUBLE);
236+
INDArray [] mesh = Nd4j.meshgrid(xPixels, yPixels);
237+
238+
xPixels = xPixels.reshape(w, 1); // This is a hack to work around a bug in INDArray.get()
239+
yPixels = yPixels.reshape(h, 1); // in the dataset generation.
240+
241+
return Nd4j.vstack(mesh[0].ravel(), mesh[1].ravel()).transpose();
242+
}
243+
}

0 commit comments

Comments
 (0)