@@ -55,7 +55,7 @@ public class GenerateTxtCharCompGraphModel {
 
     @SuppressWarnings("ConstantConditions")
     public static void main(String[] args) throws Exception {
-        int lstmLayerSize = 200;                    //Number of units in each LSTM layer
+        int lstmLayerSize = 77;                     //Number of units in each LSTM layer
         int miniBatchSize = 32;                     //Size of mini batch to use when training
         int exampleLength = 1000;                   //Length of each training example sequence to use. This could certainly be increased
         int tbpttLength = 50;                       //Length for truncated backpropagation through time, i.e., do parameter updates every 50 characters
@@ -90,18 +90,20 @@ public static void main(String[] args) throws Exception {
             //Output layer, name "outputLayer" with inputs from the two layers called "first" and "second"
             .addLayer("outputLayer", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                 .activation(Activation.SOFTMAX)
-                .nIn(2 * lstmLayerSize).nOut(nOut).build(), "first", "second")
+                .nIn(lstmLayerSize).nOut(lstmLayerSize).build(), "second")
             .setOutputs("outputLayer")  //List the output. For a ComputationGraph with multiple outputs, this also defines the input array orders
-            .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
+            .backpropType(BackpropType.TruncatedBPTT)
+            .tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
             .build();
 
         ComputationGraph net = new ComputationGraph(conf);
         net.init();
         net.setListeners(new ScoreIterationListener(1));
+        System.out.println(net.summary());
 
         //Print the number of parameters in the network (and for each layer)
         long totalNumParams = 0;
-        for (int i = 0; i < net.getNumLayers(); i++ ) {
+        for (int i = 0; i < net.getNumLayers(); i++) {
            long nParams = net.getLayer(i).numParams();
            System.out.println("Number of parameters in layer " + i + ": " + nParams);
            totalNumParams += nParams;
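
Aside on the rewiring above: the hunk only shows the tail of the builder chain. Because the output layer now takes input from "second" alone, its nIn drops from 2 * lstmLayerSize to lstmLayerSize. Note that nOut on a softmax/MCXENT output layer still has to equal the number of label classes (the valid-character count), which is presumably why lstmLayerSize is set to 77 above. A minimal sketch of what the full graph configuration might look like after this change, assuming the two recurrent layers keep their original names "first" and "second" and reusing the example's existing imports and locals (nIn, lstmLayerSize, tbpttLength); the seed, updater, and LSTM layer class here are illustrative assumptions, not taken from this diff:

```java
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(12345)                                    //Illustrative seed, not from this diff
    .updater(new Adam(0.005))                       //Illustrative updater, not from this diff
    .graphBuilder()
    .addInputs("input")
    .addLayer("first", new LSTM.Builder().nIn(nIn).nOut(lstmLayerSize)
        .activation(Activation.TANH).build(), "input")
    .addLayer("second", new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
        .activation(Activation.TANH).build(), "first")
    //Output layer is now fed by "second" only, so nIn is lstmLayerSize rather than 2*lstmLayerSize
    .addLayer("outputLayer", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
        .activation(Activation.SOFTMAX)
        .nIn(lstmLayerSize).nOut(lstmLayerSize).build(), "second")
    .setOutputs("outputLayer")
    .backpropType(BackpropType.TruncatedBPTT)
    .tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
    .build();
```
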
@@ -110,16 +112,18 @@ public static void main(String[] args) throws Exception {
 
         //Do training, and then generate and print samples from network
         int miniBatchNumber = 0;
-        for (int i = 0; i < numEpochs; i++ ) {
+        for (int i = 0; i < numEpochs; i++) {
             while (iter.hasNext()) {
                 DataSet ds = iter.next();
+                System.out.println("Input shape " + ds.getFeatures().shapeInfoToString());
+                System.out.println("Labels " + ds.getLabels().shapeInfoToString());
                 net.fit(ds);
                 if (++miniBatchNumber % generateSamplesEveryNMinibatches == 0) {
                     System.out.println("--------------------");
                     System.out.println("Completed " + miniBatchNumber + " minibatches of size " + miniBatchSize + "x" + exampleLength + " characters");
                     System.out.println("Sampling characters from network given initialization \"" + (generationInitialization == null ? "" : generationInitialization) + "\"");
                     String[] samples = sampleCharactersFromNetwork(generationInitialization, net, iter, rng, nCharactersToSample, nSamplesToGenerate);
-                    for (int j = 0; j < samples.length; j++ ) {
+                    for (int j = 0; j < samples.length; j++) {
                         System.out.println("----- Sample " + j + " -----");
                         System.out.println(samples[j]);
                         System.out.println();
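
For context on the two debug prints added above: DL4J recurrent networks consume rank-3 arrays of shape [miniBatchSize, featureCount, timeSeriesLength], so for this character-level iterator both features and labels should come out as [32, nValidCharacters, 1000] per mini batch; the prints make that geometry visible before each fit() call.
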
@@ -135,7 +139,7 @@ public static void main(String[] args) throws Exception {
 
     /** Generate a sample from the network, given an (optional, possibly null) initialization. Initialization
      * can be used to 'prime' the RNN with a sequence you want to extend/continue.<br>
-     * Note that the initalization is used for all samples
+     * Note that the initialization is used for all samples
      * @param initialization String, may be null. If null, select a random character as initialization for all samples
      * @param charactersToSample Number of characters to sample from network (excluding initialization)
      * @param net ComputationGraph with one or more LSTM/RNN layers and a softmax output layer
@@ -151,9 +155,9 @@ private static String[] sampleCharactersFromNetwork(String initialization, Comp
         //Create input for initialization
         INDArray initializationInput = Nd4j.zeros(numSamples, iter.inputColumns(), initialization.length());
         char[] init = initialization.toCharArray();
-        for (int i = 0; i < init.length; i++ ) {
+        for (int i = 0; i < init.length; i++) {
             int idx = iter.convertCharacterToIndex(init[i]);
-            for (int j = 0; j < numSamples; j++ ) {
+            for (int j = 0; j < numSamples; j++) {
                 initializationInput.putScalar(new int[]{j, idx, i}, 1.0f);
             }
         }
@@ -167,13 +171,13 @@ private static String[] sampleCharactersFromNetwork(String initialization, Comp
         INDArray output = net.rnnTimeStep(initializationInput)[0];
         output = output.tensorAlongDimension((int)output.size(2) - 1, 1, 0);  //Gets the last time step output
 
-        for (int i = 0; i < charactersToSample; i++ ) {
+        for (int i = 0; i < charactersToSample; i++) {
             //Set up next input (single time step) by sampling from previous output
             INDArray nextInput = Nd4j.zeros(numSamples, iter.inputColumns());
             //Output is a probability distribution. Sample from this for each example we want to generate, and add it to the new input
             for (int s = 0; s < numSamples; s++) {
                 double[] outputProbDistribution = new double[iter.totalOutcomes()];
-                for (int j = 0; j < outputProbDistribution.length; j++ ) outputProbDistribution[j] = output.getDouble(s, j);
+                for (int j = 0; j < outputProbDistribution.length; j++) outputProbDistribution[j] = output.getDouble(s, j);
                 int sampledCharacterIdx = GenerateTxtModel.sampleFromDistribution(outputProbDistribution, rng);
 
                 nextInput.putScalar(new int[]{s, sampledCharacterIdx}, 1.0f);  //Prepare next time step input
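
The helper GenerateTxtModel.sampleFromDistribution is called here but not shown in the diff. A minimal, self-contained sketch of what such a helper might do, assuming plain inverse-CDF sampling over the softmax output (the class name and exception behavior are assumptions, not the PR's actual code):

```java
import java.util.Random;

public class SamplingSketch {
    /** Hypothetical stand-in for GenerateTxtModel.sampleFromDistribution:
     *  draws one index i with probability distribution[i] via inverse-CDF sampling. */
    public static int sampleFromDistribution(double[] distribution, Random rng) {
        double d = rng.nextDouble();    //Uniform draw in [0, 1)
        double sum = 0.0;
        for (int i = 0; i < distribution.length; i++) {
            sum += distribution[i];
            if (d <= sum) return i;     //First index where the cumulative mass reaches d
        }
        //Reachable only if rounding error leaves the cumulative sum below d
        throw new IllegalArgumentException("Invalid distribution, cumulative sum = " + sum);
    }
}
```

Sampling from the distribution rather than taking the argmax is what keeps the generated text varied from sample to sample.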