|
| 1 | +/* ***************************************************************************** |
| 2 | + * Copyright (c) 2020 Konduit, Inc. |
| 3 | + * |
| 4 | + * This program and the accompanying materials are made available under the |
| 5 | + * terms of the Apache License, Version 2.0 which is available at |
| 6 | + * https://www.apache.org/licenses/LICENSE-2.0. |
| 7 | + * |
| 8 | + * Unless required by applicable law or agreed to in writing, software |
| 9 | + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| 10 | + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| 11 | + * License for the specific language governing permissions and limitations |
| 12 | + * under the License. |
| 13 | + * |
| 14 | + * SPDX-License-Identifier: Apache-2.0 |
| 15 | + ******************************************************************************/ |
| 16 | + |
| 17 | +package org.deeplearning4j.examples.dataexamples; |
| 18 | + |
| 19 | +import org.datavec.image.loader.Java2DNativeImageLoader; |
| 20 | +import org.deeplearning4j.examples.download.DownloaderUtility; |
| 21 | +import org.deeplearning4j.nn.api.OptimizationAlgorithm; |
| 22 | +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; |
| 23 | +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; |
| 24 | +import org.deeplearning4j.nn.conf.layers.DenseLayer; |
| 25 | +import org.deeplearning4j.nn.conf.layers.OutputLayer; |
| 26 | +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; |
| 27 | +import org.deeplearning4j.nn.weights.WeightInit; |
| 28 | +import org.nd4j.linalg.activations.Activation; |
| 29 | +import org.nd4j.linalg.api.buffer.DataType; |
| 30 | +import org.nd4j.linalg.api.ndarray.INDArray; |
| 31 | +import org.nd4j.linalg.dataset.DataSet; |
| 32 | +import org.nd4j.linalg.factory.Nd4j; |
| 33 | +import org.nd4j.linalg.indexing.BooleanIndexing; |
| 34 | +import org.nd4j.linalg.indexing.conditions.Conditions; |
| 35 | +import org.nd4j.linalg.learning.config.Adam; |
| 36 | +import org.nd4j.linalg.lossfunctions.LossFunctions; |
| 37 | +import org.nd4j.shade.guava.collect.Streams; |
| 38 | + |
| 39 | +import javax.imageio.ImageIO; |
| 40 | +import javax.swing.*; |
| 41 | +import java.awt.image.BufferedImage; |
| 42 | +import java.awt.image.DataBufferByte; |
| 43 | +import java.io.File; |
| 44 | +import java.util.Random; |
| 45 | + |
| 46 | +/** |
| 47 | + * Application to show a neural network learning to draw an image. |
| 48 | + * Demonstrates how to feed an NN with externally originated data. |
| 49 | + * |
| 50 | + * Updates from previous versions: |
| 51 | + * - Now uses swing. No longer uses JavaFX which caused problems with the OpenJDK. |
| 52 | + * |
| 53 | + * @author Robert Altena |
| 54 | + * Many thanks to @tmanthey for constructive feedback and suggestions. |
| 55 | + */ |
| 56 | +public class ImageDrawer { |
| 57 | + |
| 58 | + private JFrame mainFrame; |
| 59 | + private MultiLayerNetwork nn; // The neural network. |
| 60 | + |
| 61 | + private BufferedImage originalImage; |
| 62 | + private JLabel generatedLabel; |
| 63 | + |
| 64 | + private INDArray xyOut; //x,y grid to calculate the output image. Needs to be calculated once, then re-used. |
| 65 | + |
| 66 | + private Java2DNativeImageLoader j2dNil; //Datavec class used to read and write images to /from INDArrays. |
| 67 | + private FastRGB rgb; // helper class for fast access to the image pixels. |
| 68 | + private Random random; |
| 69 | + |
| 70 | + private void init() throws Exception { |
| 71 | + |
| 72 | + mainFrame = new JFrame("Image drawer example");//creating instance of JFrame |
| 73 | + mainFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); |
| 74 | + |
| 75 | + String localDataPath = DownloaderUtility.DATAEXAMPLES.Download(); |
| 76 | + originalImage = ImageIO.read(new File(localDataPath, "Mona_Lisa.png")); |
| 77 | + |
| 78 | + //start with a blank image of the same size as the original. |
| 79 | + BufferedImage generatedImage = new BufferedImage(originalImage.getWidth(), originalImage.getHeight(), originalImage.getType()); |
| 80 | + |
| 81 | + int width = originalImage.getWidth(); |
| 82 | + int height = originalImage.getHeight(); |
| 83 | + |
| 84 | + final JLabel originalLabel = new JLabel(new ImageIcon(originalImage)); |
| 85 | + generatedLabel = new JLabel(new ImageIcon(generatedImage)); |
| 86 | + |
| 87 | + originalLabel.setBounds(0,0, width, height); |
| 88 | + generatedLabel.setBounds(width, 0, width, height);//x axis, y axis, width, height |
| 89 | + |
| 90 | + mainFrame.add(originalLabel); |
| 91 | + mainFrame.add(generatedLabel); |
| 92 | + |
| 93 | + mainFrame.setSize(2*width, height +25); |
| 94 | + mainFrame.setLayout(null); |
| 95 | + mainFrame.setVisible(true); // Show UI |
| 96 | + |
| 97 | + |
| 98 | + j2dNil = new Java2DNativeImageLoader(); //Datavec class used to write images. |
| 99 | + random = new Random(); |
| 100 | + nn = createNN(); // Create the neural network. |
| 101 | + xyOut = calcGrid(); //Create a mesh used to generate the image. |
| 102 | + |
| 103 | + // read the color channels from the original image. |
| 104 | + rgb = new FastRGB(originalImage); |
| 105 | + |
| 106 | + SwingUtilities.invokeLater(this::onCalc); |
| 107 | + } |
| 108 | + |
| 109 | + public static void main(String[] args) throws Exception { |
| 110 | + ImageDrawer imageDrawer = new ImageDrawer(); |
| 111 | + imageDrawer.init(); |
| 112 | + } |
| 113 | + |
| 114 | + /** |
| 115 | + * Build the Neural network. |
| 116 | + */ |
| 117 | + private static MultiLayerNetwork createNN() { |
| 118 | + int seed = 2345; |
| 119 | + double learningRate = 0.001; |
| 120 | + int numInputs = 2; // x and y. |
| 121 | + int numHiddenNodes = 1000; |
| 122 | + int numOutputs = 3 ; //R, G and B value. |
| 123 | + |
| 124 | + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() |
| 125 | + .seed(seed) |
| 126 | + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) |
| 127 | + .weightInit(WeightInit.XAVIER) |
| 128 | + .updater(new Adam(learningRate)) |
| 129 | + .list() |
| 130 | + .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes) |
| 131 | + .activation(Activation.LEAKYRELU) |
| 132 | + .build()) |
| 133 | + .layer(new DenseLayer.Builder().nOut(numHiddenNodes ) |
| 134 | + .activation(Activation.LEAKYRELU) |
| 135 | + .build()) |
| 136 | + .layer(new DenseLayer.Builder().nOut(numHiddenNodes ) |
| 137 | + .activation(Activation.LEAKYRELU) |
| 138 | + .build()) |
| 139 | + .layer(new DenseLayer.Builder().nOut(numHiddenNodes ) |
| 140 | + .activation(Activation.LEAKYRELU) |
| 141 | + .build()) |
| 142 | + .layer(new DenseLayer.Builder().nOut(numHiddenNodes ) |
| 143 | + .activation(Activation.LEAKYRELU) |
| 144 | + .build()) |
| 145 | + .layer( new OutputLayer.Builder(LossFunctions.LossFunction.L2) |
| 146 | + .activation(Activation.IDENTITY) |
| 147 | + .nOut(numOutputs).build()) |
| 148 | + .build(); |
| 149 | + |
| 150 | + MultiLayerNetwork net = new MultiLayerNetwork(conf); |
| 151 | + net.init(); |
| 152 | + |
| 153 | + return net; |
| 154 | + } |
| 155 | + |
| 156 | + /** |
| 157 | + * Training the NN and updating the current graphical output. |
| 158 | + */ |
| 159 | + private void onCalc(){ |
| 160 | + // Find a reasonable balance between batch size and number of batches per generated redraw. |
| 161 | + int batchSize = 1000; //larger batch size slows the calculation but speeds up the learning per batch |
| 162 | + int numBatches = 10; // Drawing the generated image is slow. Doing multiple batches before redrawing increases speed. |
| 163 | + for (int i =0; i< numBatches; i++){ |
| 164 | + DataSet ds = generateDataSet(batchSize); |
| 165 | + nn.fit(ds); |
| 166 | + } |
| 167 | + drawImage(); |
| 168 | + mainFrame.invalidate(); |
| 169 | + mainFrame.repaint(); |
| 170 | + |
| 171 | + SwingUtilities.invokeLater(this::onCalc); //TODO: move training to a worker thread, |
| 172 | + } |
| 173 | + |
| 174 | + /** |
| 175 | + * Take a batchsize of random samples from the source image. |
| 176 | + * This illustrates how to generate a custom dataset. The normal way of doing this would be to generate a dataset |
| 177 | + * of the entire source image, train om shuffled batches from there. |
| 178 | + * |
| 179 | + * @param batchSize number of sample points to take out of the image. |
| 180 | + * @return DeepLearning4J DataSet. |
| 181 | + */ |
| 182 | + private DataSet generateDataSet(int batchSize) { |
| 183 | + int w = originalImage.getWidth(); |
| 184 | + int h = originalImage.getHeight(); |
| 185 | + |
| 186 | + float[][] in = new float[batchSize][2]; |
| 187 | + float[][] out = new float[batchSize][3]; |
| 188 | + final int[] i = {0}; |
| 189 | + Streams.forEachPair( |
| 190 | + random.ints(batchSize, 0, w).boxed(), |
| 191 | + random.ints(batchSize, 0, h).boxed(), |
| 192 | + (a, b) -> { |
| 193 | + final short[] parts = rgb.getRGB(a, b); |
| 194 | + in[i[0]] = new float[]{((a / (float)w) - 0.5f) * 2f, ((b / (float)h) - 0.5f) * 2f}; |
| 195 | + out[i[0]] = new float[]{parts[0], parts[1], parts[2]}; |
| 196 | + i[0]++; |
| 197 | + } |
| 198 | + ); |
| 199 | + final INDArray input = Nd4j.create(in); |
| 200 | + final INDArray labels = Nd4j.create(out).divi(255); |
| 201 | + return new DataSet(input, labels); |
| 202 | + } |
| 203 | + |
| 204 | + /** |
| 205 | + * Make the Neural network draw the image. |
| 206 | + */ |
| 207 | + private void drawImage() { |
| 208 | + int w = originalImage.getWidth(); |
| 209 | + int h = originalImage.getHeight(); |
| 210 | + |
| 211 | + INDArray out = nn.output(xyOut); // The raw NN output. |
| 212 | + BooleanIndexing.replaceWhere(out, 0.0, Conditions.lessThan(0.0)); // Clip between 0 and 1. |
| 213 | + BooleanIndexing.replaceWhere(out, 1.0, Conditions.greaterThan(1.0)); |
| 214 | + out = out.mul(255).castTo(DataType.BYTE); //convert to bytes. |
| 215 | + |
| 216 | + INDArray r = out.getColumn(0); //Extract the individual color layers. |
| 217 | + INDArray g = out.getColumn(1); |
| 218 | + INDArray b = out.getColumn(2); |
| 219 | + |
| 220 | + INDArray imgArr = Nd4j.vstack(b, g, r).reshape(3, h, w); // recombine the colors and reshape to image size. |
| 221 | + |
| 222 | + BufferedImage img = j2dNil.asBufferedImage(imgArr); //update the UI. |
| 223 | + generatedLabel.setIcon(new ImageIcon(img)); |
| 224 | + } |
| 225 | + |
| 226 | + /** |
| 227 | + * The x,y grid to calculate the NN output. Only needs to be calculated once. |
| 228 | + */ |
| 229 | + private INDArray calcGrid(){ |
| 230 | + int w = originalImage.getWidth(); |
| 231 | + int h = originalImage.getHeight(); |
| 232 | + INDArray xPixels = Nd4j.linspace(-1.0, 1.0, w, DataType.DOUBLE); |
| 233 | + INDArray yPixels = Nd4j.linspace(-1.0, 1.0, h, DataType.DOUBLE); |
| 234 | + INDArray [] mesh = Nd4j.meshgrid(xPixels, yPixels); |
| 235 | + |
| 236 | + return Nd4j.vstack(mesh[0].ravel(), mesh[1].ravel()).transpose(); |
| 237 | + } |
| 238 | + |
| 239 | + |
| 240 | + public class FastRGB { |
| 241 | + int width; |
| 242 | + int height; |
| 243 | + private boolean hasAlphaChannel; |
| 244 | + private int pixelLength; |
| 245 | + private byte[] pixels; |
| 246 | + |
| 247 | + FastRGB(BufferedImage image) { |
| 248 | + pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData(); |
| 249 | + width = image.getWidth(); |
| 250 | + height = image.getHeight(); |
| 251 | + hasAlphaChannel = image.getAlphaRaster() != null; |
| 252 | + pixelLength = 3; |
| 253 | + if (hasAlphaChannel) |
| 254 | + pixelLength = 4; |
| 255 | + } |
| 256 | + |
| 257 | + short[] getRGB(int x, int y) { |
| 258 | + int pos = (y * pixelLength * width) + (x * pixelLength); |
| 259 | + short rgb[] = new short[4]; |
| 260 | + if (hasAlphaChannel) |
| 261 | + rgb[3] = (short) (pixels[pos++] & 0xFF); // Alpha |
| 262 | + rgb[2] = (short) (pixels[pos++] & 0xFF); // Blue |
| 263 | + rgb[1] = (short) (pixels[pos++] & 0xFF); // Green |
| 264 | + rgb[0] = (short) (pixels[pos] & 0xFF); // Red |
| 265 | + return rgb; |
| 266 | + } |
| 267 | + } |
| 268 | +} |
0 commit comments