11/*******************************************************************************
2- *
3- *
42 *
53 * This program and the accompanying materials are made available under the
64 * terms of the Apache License, Version 2.0 which is available at
1715 * SPDX-License-Identifier: Apache-2.0
1816 ******************************************************************************/
1917
20- package org .deeplearning4j .examples .wip . advanced .modelling .melodl4j ;
18+ package org .deeplearning4j .examples .advanced .modelling . charmodelling .melodl4j ;
2119
2220import org .apache .commons .io .FileUtils ;
2321import org .deeplearning4j .examples .advanced .modelling .charmodelling .utils .CharacterIterator ;
3129import org .deeplearning4j .nn .weights .WeightInit ;
3230import org .deeplearning4j .optimize .listeners .ScoreIterationListener ;
3331import org .deeplearning4j .util .ModelSerializer ;
32+ import org .nd4j .common .util .ArchiveUtils ;
3433import org .nd4j .linalg .activations .Activation ;
3534import org .nd4j .linalg .api .ndarray .INDArray ;
3635import org .nd4j .linalg .dataset .DataSet ;
3736import org .nd4j .linalg .factory .Nd4j ;
37+ import org .nd4j .linalg .learning .config .Adam ;
3838import org .nd4j .linalg .learning .config .RmsProp ;
3939import org .nd4j .linalg .lossfunctions .LossFunctions ;
4040
41+ import javax .sound .midi .InvalidMidiDataException ;
4142import java .io .*;
4243import java .net .URL ;
4344import java .nio .charset .Charset ;
45+ import java .nio .file .Files ;
46+ import java .nio .file .Path ;
47+ import java .text .NumberFormat ;
4448import java .util .ArrayList ;
4549import java .util .List ;
4650import java .util .Random ;
51+ import java .util .zip .ZipEntry ;
52+ import java .util .zip .ZipInputStream ;
4753
4854/**
4955 * LSTM Symbolic melody modelling example, to compose music from symbolic melodies extracted from MIDI.
50- * Based closely on LSTMCharModellingExample.java.
56+ * LSTM logic is based closely on LSTMCharModellingExample.java.
5157 * See the README file in this directory for documentation.
5258 *
5359 * @author Alex Black, Donald A. Smith.
5460 */
5561public class MelodyModelingExample {
56- final static String inputSymbolicMelodiesFilename = "bach-melodies-input.txt" ;
57- // Examples: bach-melodies-input.txt, beatles-melodies-input.txt , pop-melodies-input.txt (large)
62+ // If you want to change the MIDI files used in learning, create a zip file containing your MIDI
63+ // files and replace the following path. For example, you might use something like:
64+ //final static String midiFileZipFileUrlPath = "file:d:/music/midi/classical-midi.zip";
65+ final static String midiFileZipFileUrlPath = "http://waliberals.org/truthsite/music/bach-midi.zip" ;
5866
59- final static String tmpDir = System .getProperty ("java.io.tmpdir" );
67+ // For example "bach-midi.txt"
68+ final static String inputSymbolicMelodiesFilename = getMelodiesFileNameFromURLPath (midiFileZipFileUrlPath );
6069
61- final static String symbolicMelodiesInputFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename ; // Point to melodies created by MidiMelodyExtractor.java
70+ // Examples: bach-melodies-input.txt, beatles-melodies-input.txt , pop-melodies-input.txt (large)
71+ final static String tmpDir = System .getProperty ("java.io.tmpdir" );
72+ final static String inputSymbolicMelodiesFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename ; // Point to melodies created by MidiMelodyExtractor.java
6273 final static String composedMelodiesOutputFilePath = tmpDir + "/composition.txt" ; // You can listen to these melodies by running PlayMelodyStrings.java against this file.
6374
6475 //final static String symbolicMelodiesInputFilePath = "D:/tmp/bach-melodies.txt";
6576 //final static String composedMelodiesOutputFilePath = tmpDir + "/bach-composition.txt"; // You can listen to these melodies by running PlayMelodyStrings.java against this file.
66-
77+ final static NumberFormat numberFormat = NumberFormat .getNumberInstance ();
78+ static {
79+ numberFormat .setMinimumFractionDigits (1 );
80+ numberFormat .setMaximumFractionDigits (1 );
81+ }
6782 //....
6883 public static void main (String [] args ) throws Exception {
6984 String loadNetworkPath = null ; //"/tmp/MelodyModel-bach.zip"; //null;
@@ -73,6 +88,8 @@ public static void main(String[] args) throws Exception {
7388 generationInitialization = args [1 ];
7489 }
7590
91+ makeMidiStringFileIfNecessary ();
92+
7693 int lstmLayerSize = 200 ; //Number of units in each LSTM layer
7794 int miniBatchSize = 32 ; //Size of mini batch to use when training
7895 int exampleLength = 500 ; //1000; //Length of each training example sequence to use.
@@ -107,9 +124,10 @@ public static void main(String[] args) throws Exception {
107124
108125 //Set up network configuration:
109126 MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
110- .updater (new RmsProp (0.1 ))
111- .seed (12345 )
112- .l2 (0.001 )
127+ //.updater(new RmsProp(0.1))
128+ .updater (new Adam (0.005 ))
129+ .seed (System .currentTimeMillis ()) // So each run generates new melodies
130+ .l2 (0.0001 )
113131 .weightInit (WeightInit .XAVIER )
114132 .list ()
115133 .layer (0 , new LSTM .Builder ().nIn (iter .inputColumns ()).nOut (lstmLayerSize )
@@ -123,7 +141,6 @@ public static void main(String[] args) throws Exception {
123141 .backpropType (BackpropType .TruncatedBPTT ).tBPTTForwardLength (tbpttLength ).tBPTTBackwardLength (tbpttLength )
124142 .build ();
125143
126-
127144 learn (miniBatchSize , exampleLength , numEpochs , generateSamplesEveryNMinibatches , nSamplesToGenerate , nCharactersToSample , generationInitialization , rng , startTime , iter , conf );
128145 }
129146
@@ -154,6 +171,7 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
154171 // order, so that the best melodies are at the start of the file.
155172 //Do training, and then generate and print samples from network
156173 int miniBatchNumber = 0 ;
174+ long lastTime = System .currentTimeMillis ();
157175 for (int epoch = 0 ; epoch < numEpochs ; epoch ++) {
158176 System .out .println ("Starting epoch " + epoch );
159177 while (iter .hasNext ()) {
@@ -176,12 +194,19 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
176194 }
177195 }
178196 iter .reset (); //Reset iterator for another epoch
197+ final double secondsForEpoch = 0.001 * (System .currentTimeMillis () - startTime );
198+ final long now = System .currentTimeMillis ();
179199 if (melodies .size () > 0 ) {
180200 String melody = melodies .get (melodies .size () - 1 );
181201 int seconds = 25 ;
182202 System .out .println ("\n First " + seconds + " seconds of " + melody );
183203 PlayMelodyStrings .playMelody (melody , seconds );
184204 }
205+ double seconds = 0.001 *(now - lastTime );
206+ lastTime = now ;
207+ System .out .println ("\n Epoch " + epoch + " time in seconds: " + numberFormat .format (seconds ));
208+ // 531.9 for GPU GTX 1070
209+ // 821.4 for CPU i7-6700K @ 4GHZ
185210 }
186211 int indexOfLastPeriod = inputSymbolicMelodiesFilename .lastIndexOf ('.' );
187212 String saveFileName = inputSymbolicMelodiesFilename .substring (0 , indexOfLastPeriod > 0 ? indexOfLastPeriod : inputSymbolicMelodiesFilename .length ());
@@ -193,42 +218,82 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
193218 printWriter .println (melodies .get (i ));
194219 }
195220 printWriter .close ();
196- double seconds = 0.001 * (System .currentTimeMillis () - startTime );
197221
198- System .out .println ("\n \n Example complete in " + seconds + " seconds" );
199222 System .exit (0 );
200223 }
201224
202- public static void makeSureFileIsInTmpDir (String filename ) {
225+ public static File makeSureFileIsInTmpDir (String urlString ) throws IOException {
226+ final URL url = new URL (urlString );
227+ final String filename = urlString .substring (1 +urlString .lastIndexOf ("/" ));
203228 final File f = new File (tmpDir + "/" + filename );
204- if (!f .exists ()) {
205- URL url = null ;
206- try {
207- url = new URL ("http://truthsite.org/music/" + filename );
208- FileUtils .copyURLToFile (url , f );
209- } catch (Exception exc ) {
210- System .err .println ("Error copying " + url + " to " + f );
211- throw new RuntimeException (exc );
212- }
229+ if (f .exists ()) {
230+ System .out .println ("Using existing " + f .getAbsolutePath ());
231+ } else {
232+ FileUtils .copyURLToFile (url , f );
213233 if (!f .exists ()) {
214234 throw new RuntimeException (f .getAbsolutePath () + " does not exist" );
215235 }
216236 System .out .println ("File downloaded to " + f .getAbsolutePath ());
217- } else {
218- System .out .println ("Using existing text file at " + f .getAbsolutePath ());
219237 }
238+ return f ;
220239 }
221240
241+ //https://stackoverflow.com/questions/10633595/java-zip-how-to-unzip-folder
242+ public static void unzip (File zipFile , File targetDirFile ) throws IOException {
243+ InputStream is = new FileInputStream (zipFile );
244+ Path targetDir = targetDirFile .toPath ();
245+ targetDir = targetDir .toAbsolutePath ();
246+ try (ZipInputStream zipIn = new ZipInputStream (is )) {
247+ for (ZipEntry ze ; (ze = zipIn .getNextEntry ()) != null ; ) {
248+ Path resolvedPath = targetDir .resolve (ze .getName ()).normalize ();
249+ if (!resolvedPath .startsWith (targetDir )) {
250+ // see: https://snyk.io/research/zip-slip-vulnerability
251+ throw new RuntimeException ("Entry with an illegal path: "
252+ + ze .getName ());
253+ }
254+ if (ze .isDirectory ()) {
255+ Files .createDirectories (resolvedPath );
256+ } else {
257+ Files .createDirectories (resolvedPath .getParent ());
258+ Files .copy (zipIn , resolvedPath );
259+ }
260+ }
261+ }
262+ is .close ();
263+ }
264+ private static void makeMidiStringFileIfNecessary () throws IOException , InvalidMidiDataException {
265+ final File inputMelodiesFile = new File (inputSymbolicMelodiesFilePath );
266+ if (inputMelodiesFile .exists () && inputMelodiesFile .length ()>1000 ) {
267+ System .out .println ("Using existing " + inputSymbolicMelodiesFilePath );
268+ return ;
269+ }
270+ final File midiZipFile = makeSureFileIsInTmpDir (midiFileZipFileUrlPath );
271+ final String midiZipFileName = midiZipFile .getName ();
272+ final String midiZipFileNameWithoutSuffix = midiZipFileName .substring (0 ,midiZipFileName .lastIndexOf ("." ));
273+ final File outputDirectoryFile = new File (tmpDir ,midiZipFileNameWithoutSuffix );
274+ final String outputDirectoryPath = outputDirectoryFile .getAbsolutePath ();
275+ if (!outputDirectoryFile .exists ()) {
276+ outputDirectoryFile .mkdir ();
277+ }
278+ if (!outputDirectoryFile .exists () || !outputDirectoryFile .isDirectory ()) {
279+ throw new IllegalStateException (outputDirectoryFile + " is not a directory or can't be created" );
280+ }
281+ final PrintStream printStream = new PrintStream (inputSymbolicMelodiesFilePath );
282+ System .out .println ("Unzipping " + midiZipFile .getAbsolutePath () + " to " + outputDirectoryPath );
283+ unzip (midiZipFile , outputDirectoryFile );
284+ System .out .println ("Extracted " + midiZipFile .getAbsolutePath () + " to " + outputDirectoryPath );
285+ MidiMelodyExtractor .processDirectoryAndWriteMelodyFile (outputDirectoryFile ,inputMelodiesFile );
286+ printStream .close ();
287+ }
222288 /**
223289 * Sets up and return a simple DataSetIterator that does vectorization based on the melody sample.
224290 *
225291 * @param miniBatchSize Number of text segments in each training mini-batch
226292 * @param sequenceLength Number of characters in each text segment.
227293 */
228294 public static CharacterIterator getMidiIterator (int miniBatchSize , int sequenceLength ) throws Exception {
229- makeSureFileIsInTmpDir (inputSymbolicMelodiesFilename );
230295 final char [] validCharacters = MelodyStrings .allValidCharacters .toCharArray (); //Which characters are allowed? Others will be removed
231- return new CharacterIterator (symbolicMelodiesInputFilePath , Charset .forName ("UTF-8" ),
296+ return new CharacterIterator (inputSymbolicMelodiesFilePath , Charset .forName ("UTF-8" ),
232297 miniBatchSize , sequenceLength , validCharacters , new Random (12345 ), MelodyStrings .COMMENT_STRING );
233298 }
234299
@@ -312,5 +377,13 @@ public static int sampleFromDistribution(double[] distribution, Random rng) {
312377 //Should be extremely unlikely to happen if distribution is a valid probability distribution
313378 throw new IllegalArgumentException ("Distribution is invalid? d=" + d + ", sum=" + sum );
314379 }
380+ private static String getMelodiesFileNameFromURLPath (String midiFileZipFileUrlPath ) {
381+ if (!(midiFileZipFileUrlPath .endsWith (".zip" ) || midiFileZipFileUrlPath .endsWith (".ZIP" ))) {
382+ throw new IllegalStateException ("zipFilePath must end with .zip" );
383+ }
384+ midiFileZipFileUrlPath = midiFileZipFileUrlPath .replace ('\\' ,'/' );
385+ String fileName = midiFileZipFileUrlPath .substring (midiFileZipFileUrlPath .lastIndexOf ("/" ) + 1 );
386+ return fileName + ".txt" ;
387+ }
315388}
316389
0 commit comments