34
34
import org .nd4j .linalg .indexing .conditions .Conditions ;
35
35
import org .nd4j .linalg .learning .config .Adam ;
36
36
import org .nd4j .linalg .lossfunctions .LossFunctions ;
37
+ import org .nd4j .shade .guava .collect .Streams ;
37
38
38
39
import javax .imageio .ImageIO ;
39
40
import javax .swing .*;
40
41
import java .awt .image .BufferedImage ;
42
+ import java .awt .image .DataBufferByte ;
41
43
import java .io .File ;
44
+ import java .util .Random ;
42
45
43
46
/**
44
47
* Application to show a neural network learning to draw an image.
45
48
* Demonstrates how to feed an NN with externally originated data.
46
49
*
47
50
* Updates from previous versions:
48
51
* - Now uses swing. No longer uses JavaFX which caused problems with the OpenJDK.
49
- * - All slow java loops in the dataset creation and image drawing are replaced with fast vectorized code.
50
52
*
51
53
* @author Robert Altena
52
54
* Many thanks to @tmanthey for constructive feedback and suggestions.
@@ -59,17 +61,11 @@ public class ImageDrawer {
59
61
private BufferedImage originalImage ;
60
62
private JLabel generatedLabel ;
61
63
62
- private INDArray blueMat ; // color channels of he original image.
63
- private INDArray greenMat ;
64
- private INDArray redMat ;
65
-
66
- private INDArray xPixels ; // x coordinates of the pixels for the NN.
67
- private INDArray yPixels ; // y coordinates of the pixels for the NN.
68
-
69
64
private INDArray xyOut ; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.
70
65
71
66
private Java2DNativeImageLoader j2dNil ; //Datavec class used to read and write images to /from INDArrays.
72
-
67
+ private FastRGB rgb ; // helper class for fast access to the image pixels.
68
+ private Random random ;
73
69
74
70
private void init () throws Exception {
75
71
@@ -78,6 +74,7 @@ private void init() throws Exception {
78
74
79
75
String localDataPath = DownloaderUtility .DATAEXAMPLES .Download ();
80
76
originalImage = ImageIO .read (new File (localDataPath , "Mona_Lisa.png" ));
77
+
81
78
//start with a blank image of the same size as the original.
82
79
BufferedImage generatedImage = new BufferedImage (originalImage .getWidth (), originalImage .getHeight (), originalImage .getType ());
83
80
@@ -98,15 +95,13 @@ private void init() throws Exception {
98
95
mainFrame .setVisible (true ); // Show UI
99
96
100
97
101
- j2dNil = new Java2DNativeImageLoader (); //Datavec class used to read and write images.
98
+ j2dNil = new Java2DNativeImageLoader (); //Datavec class used to write images.
99
+ random = new Random ();
102
100
nn = createNN (); // Create the neural network.
103
101
xyOut = calcGrid (); //Create a mesh used to generate the image.
104
102
105
103
// read the color channels from the original image.
106
- INDArray imageMat = j2dNil .asMatrix (originalImage ).castTo (DataType .DOUBLE ).div (255.0 );
107
- blueMat = imageMat .tensorAlongDimension (1 , 0 , 2 , 3 ).reshape (width * height , 1 );
108
- greenMat = imageMat .tensorAlongDimension (2 , 0 , 2 , 3 ).reshape (width * height , 1 );
109
- redMat = imageMat .tensorAlongDimension (3 , 0 , 2 , 3 ).reshape (width * height , 1 );
104
+ rgb = new FastRGB (originalImage );
110
105
111
106
SwingUtilities .invokeLater (this ::onCalc );
112
107
}
@@ -127,30 +122,30 @@ private static MultiLayerNetwork createNN() {
127
122
int numOutputs = 3 ; //R, G and B value.
128
123
129
124
MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
130
- .seed (seed )
131
- .optimizationAlgo (OptimizationAlgorithm .STOCHASTIC_GRADIENT_DESCENT )
132
- .weightInit (WeightInit .XAVIER )
133
- .updater (new Adam (learningRate ))
134
- .list ()
135
- .layer (new DenseLayer .Builder ().nIn (numInputs ).nOut (numHiddenNodes )
136
- .activation (Activation .LEAKYRELU )
137
- .build ())
138
- .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
139
- .activation (Activation .LEAKYRELU )
140
- .build ())
141
- .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
142
- .activation (Activation .LEAKYRELU )
143
- .build ())
144
- .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
145
- .activation (Activation .LEAKYRELU )
146
- .build ())
147
- .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
148
- .activation (Activation .LEAKYRELU )
149
- .build ())
150
- .layer ( new OutputLayer .Builder (LossFunctions .LossFunction .L2 )
151
- .activation (Activation .IDENTITY )
152
- .nOut (numOutputs ).build ())
153
- .build ();
125
+ .seed (seed )
126
+ .optimizationAlgo (OptimizationAlgorithm .STOCHASTIC_GRADIENT_DESCENT )
127
+ .weightInit (WeightInit .XAVIER )
128
+ .updater (new Adam (learningRate ))
129
+ .list ()
130
+ .layer (new DenseLayer .Builder ().nIn (numInputs ).nOut (numHiddenNodes )
131
+ .activation (Activation .LEAKYRELU )
132
+ .build ())
133
+ .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
134
+ .activation (Activation .LEAKYRELU )
135
+ .build ())
136
+ .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
137
+ .activation (Activation .LEAKYRELU )
138
+ .build ())
139
+ .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
140
+ .activation (Activation .LEAKYRELU )
141
+ .build ())
142
+ .layer (new DenseLayer .Builder ().nOut (numHiddenNodes )
143
+ .activation (Activation .LEAKYRELU )
144
+ .build ())
145
+ .layer ( new OutputLayer .Builder (LossFunctions .LossFunction .L2 )
146
+ .activation (Activation .IDENTITY )
147
+ .nOut (numOutputs ).build ())
148
+ .build ();
154
149
155
150
MultiLayerNetwork net = new MultiLayerNetwork (conf );
156
151
net .init ();
@@ -162,8 +157,9 @@ private static MultiLayerNetwork createNN() {
162
157
* Training the NN and updating the current graphical output.
163
158
*/
164
159
private void onCalc (){
165
- int batchSize = 1000 ;
166
- int numBatches = 10 ;
160
+ // Find a reasonable balance between batch size and number of batches per generated redraw.
161
+ int batchSize = 1000 ; //larger batch size slows the calculation but speeds up the learning per batch
162
+ int numBatches = 10 ; // Drawing the generated image is slow. Doing multiple batches before redrawing increases speed.
167
163
for (int i =0 ; i < numBatches ; i ++){
168
164
DataSet ds = generateDataSet (batchSize );
169
165
nn .fit (ds );
@@ -172,11 +168,13 @@ private void onCalc(){
172
168
mainFrame .invalidate ();
173
169
mainFrame .repaint ();
174
170
175
- SwingUtilities .invokeLater (this ::onCalc );
171
+ SwingUtilities .invokeLater (this ::onCalc ); //TODO: move training to a worker thread,
176
172
}
177
173
178
174
/**
179
175
* Take a batchsize of random samples from the source image.
176
+ * This illustrates how to generate a custom dataset. The normal way of doing this would be to generate a dataset
+ * of the entire source image, then train on shuffled batches from there.
180
178
*
181
179
* @param batchSize number of sample points to take out of the image.
182
180
* @return DeepLearning4J DataSet.
@@ -185,22 +183,22 @@ private DataSet generateDataSet(int batchSize) {
185
183
int w = originalImage .getWidth ();
186
184
int h = originalImage .getHeight ();
187
185
188
- INDArray xindex = Nd4j . rand ( batchSize ). muli ( w - 1 ). castTo ( DataType . UINT32 ) ;
189
- INDArray yindex = Nd4j . rand ( batchSize ). muli ( h - 1 ). castTo ( DataType . UINT32 ) ;
190
-
191
- INDArray xPos = xPixels . get ( xindex ). reshape ( batchSize ); // Look up the normalized positions pf the pixels.
192
- INDArray yPos = yPixels . get ( yindex ). reshape ( batchSize );
193
-
194
- INDArray xy = Nd4j . vstack ( xPos , yPos ). transpose (); // Create the array that can be fed into the NN.
195
-
196
- //Look up the correct colors fot our random pixels.
197
- INDArray xyIndex = yindex . mul ( w ). add ( xindex ); //TODO: figure out the 2D version of INDArray.get.
198
- INDArray b = blueMat . get ( xyIndex ). reshape ( batchSize ) ;
199
- INDArray g = greenMat . get ( xyIndex ). reshape ( batchSize );
200
- INDArray r = redMat . get ( xyIndex ). reshape ( batchSize );
201
- INDArray out = Nd4j .vstack ( r , g , b ). transpose (); // Create the array that can be used for NN training.
202
-
203
- return new DataSet (xy , out );
186
+ float [][] in = new float [ batchSize ][ 2 ] ;
187
+ float [][] out = new float [ batchSize ][ 3 ] ;
188
+ final int [] i = { 0 };
189
+ Streams . forEachPair (
190
+ random . ints ( batchSize , 0 , w ). boxed (),
191
+ random . ints ( batchSize , 0 , h ). boxed (),
192
+ ( a , b ) -> {
193
+ final short [] parts = rgb . getRGB ( a , b );
194
+ in [ i [ 0 ]] = new float []{(( a / ( float ) w ) - 0.5f ) * 2f , (( b / ( float ) h ) - 0.5f ) * 2f };
195
+ out [ i [ 0 ]] = new float []{ parts [ 0 ], parts [ 1 ], parts [ 2 ]};
196
+ i [ 0 ]++ ;
197
+ }
198
+ );
199
+ final INDArray input = Nd4j .create ( in );
200
+ final INDArray labels = Nd4j . create ( out ). divi ( 255 );
201
+ return new DataSet (input , labels );
204
202
}
205
203
206
204
/**
@@ -211,7 +209,7 @@ private void drawImage() {
211
209
int h = originalImage .getHeight ();
212
210
213
211
INDArray out = nn .output (xyOut ); // The raw NN output.
214
- BooleanIndexing .replaceWhere (out , 0.0 , Conditions .lessThan (0.0 )); // Cjip between 0 and 1.
212
+ BooleanIndexing .replaceWhere (out , 0.0 , Conditions .lessThan (0.0 )); // Clip between 0 and 1.
215
213
BooleanIndexing .replaceWhere (out , 1.0 , Conditions .greaterThan (1.0 ));
216
214
out = out .mul (255 ).castTo (DataType .BYTE ); //convert to bytes.
217
215
@@ -231,13 +229,40 @@ private void drawImage() {
231
229
/**
 * Build the input grid that is fed to the network when rendering the output
 * image: one row per pixel, holding that pixel's (x, y) position with each
 * axis normalized to [-1, 1] via linspace. Computed once and cached (xyOut).
 *
 * @return INDArray of shape [width * height, 2]; column 0 = x, column 1 = y.
 */
private INDArray calcGrid(){
    int w = originalImage.getWidth();
    int h = originalImage.getHeight();
    INDArray xPixels = Nd4j.linspace(-1.0, 1.0, w, DataType.DOUBLE);
    INDArray yPixels = Nd4j.linspace(-1.0, 1.0, h, DataType.DOUBLE);
    INDArray[] mesh = Nd4j.meshgrid(xPixels, yPixels);

    // Flatten the two mesh planes into rows, stack, and transpose so that
    // each row of the result is a single (x, y) coordinate pair.
    return Nd4j.vstack(mesh[0].ravel(), mesh[1].ravel()).transpose();
}
238
+
239
+
240
/**
 * Fast per-pixel color access for a BufferedImage backed by a byte data buffer
 * (e.g. TYPE_3BYTE_BGR or TYPE_4BYTE_ABGR). Grabs the raster's backing byte
 * array once and indexes into it directly, avoiding the per-call overhead of
 * BufferedImage.getRGB().
 */
public class FastRGB {
    final int width;
    final int height;
    private final boolean hasAlphaChannel;
    private final int pixelLength; // bytes per pixel: 3 (BGR) or 4 (ABGR)
    private final byte[] pixels;   // raster data, interleaved (A)BGR byte order

    FastRGB(BufferedImage image) {
        // Fail fast with a clear message for images whose raster is not
        // byte-backed (e.g. TYPE_INT_RGB), instead of an opaque
        // ClassCastException from the cast below.
        if (!(image.getRaster().getDataBuffer() instanceof DataBufferByte)) {
            throw new IllegalArgumentException(
                    "FastRGB requires a byte-backed image (e.g. TYPE_3BYTE_BGR); got image type " + image.getType());
        }
        pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
        width = image.getWidth();
        height = image.getHeight();
        hasAlphaChannel = image.getAlphaRaster() != null;
        pixelLength = hasAlphaChannel ? 4 : 3;
    }

    /**
     * Read one pixel's color channels.
     *
     * @param x pixel column, 0 <= x < width
     * @param y pixel row, 0 <= y < height
     * @return {R, G, B, A} as unsigned byte values (0-255); the alpha slot
     *         stays 0 when the image has no alpha channel.
     */
    short[] getRGB(int x, int y) {
        int pos = (y * pixelLength * width) + (x * pixelLength);
        short[] rgb = new short[4];
        // Byte-backed images store channels in (A)BGR order; read them in
        // storage order and place them into RGBA slots of the result.
        if (hasAlphaChannel)
            rgb[3] = (short) (pixels[pos++] & 0xFF); // Alpha
        rgb[2] = (short) (pixels[pos++] & 0xFF); // Blue
        rgb[1] = (short) (pixels[pos++] & 0xFF); // Green
        rgb[0] = (short) (pixels[pos] & 0xFF);   // Red
        return rgb;
    }
}
243
268
}
0 commit comments