Skip to content

Commit bd52ecf

Browse files
StefanSebastian Böck
authored and committed
add DrumTranscriptor program
1 parent 6511bf4 commit bd52ecf

File tree

4 files changed

+150
-0
lines changed

4 files changed

+150
-0
lines changed

bin/DrumTranscriptor

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#!/usr/bin/env python
2+
# encoding: utf-8
3+
"""
4+
Drum transcription with a convolutional recurrent neural network (CRNN).
5+
6+
"""
7+
8+
from __future__ import absolute_import, division, print_function
9+
10+
import argparse
11+
12+
from madmom.features import ActivationsProcessor
13+
from madmom.features.drums import CRNNDrumProcessor, DrumPeakPickingProcessor
14+
from madmom.features.notes import write_midi, write_notes
15+
from madmom.processors import IOProcessor, io_arguments
16+
17+
18+
def main():
    """
    Run the DrumTranscriptor command line program.

    The command line is parsed first; afterwards an input processor (stored
    activations or the CRNN drum processor) is combined with an output
    processor (an activations writer, or peak picking followed by a text or
    MIDI writer) into an IOProcessor, which is then handed to the processing
    function stored in `args.func`.

    Raises
    ------
    ValueError
        If activations are not being saved and the selected output format is
        neither the default (None) nor 'midi'.

    """
    # set up the command line interface
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description='''
    Drum transcription with a convolutional recurrent neural network (CRNN).
    ''')
    # program version
    parser.add_argument('--version', action='version',
                        version='DrumTranscriptor.2017')
    # arguments for input/output handling and activations loading/saving
    io_arguments(parser, output_suffix='.drum_transcriptor.txt')
    ActivationsProcessor.add_arguments(parser)
    # arguments for the peak picking stage, with program specific defaults
    peak_picking_defaults = dict(
        threshold=0.15, smooth=0, pre_avg=0.1, post_avg=0.01, pre_max=0.02,
        post_max=0.01, combine=0.02)
    DrumPeakPickingProcessor.add_arguments(parser, **peak_picking_defaults)
    # switch for MIDI output
    parser.add_argument('--midi', dest='output_format', action='store_const',
                        const='midi', help='save as MIDI')

    # evaluate the command line
    args = parser.parse_args()

    # the frame rate is fixed and not exposed as an option
    args.fps = 100

    # MIDI output gets its own file suffix
    if args.output_format == 'midi':
        args.output_suffix = '.mid'

    # show the effective arguments on request
    if args.verbose:
        print(args)

    # all processors are configured from the same argument namespace
    options = vars(args)

    # input: either read stored activations or compute them with the CRNN
    in_processor = (ActivationsProcessor(mode='r', **options) if args.load
                    else CRNNDrumProcessor(**options))

    # output: either store the activations or detect and write the drum hits
    if args.save:
        out_processor = ActivationsProcessor(mode='w', **options)
    else:
        # pick the peaks of the activation functions
        peak_picking = DrumPeakPickingProcessor(**options)
        # choose the writer matching the requested output format
        writers = {None: write_notes, 'midi': write_midi}
        if args.output_format not in writers:
            raise ValueError('unknown output format: %s' % args.output_format)
        out_processor = [peak_picking, writers[args.output_format]]

    # chain input and output and run the selected processing function
    processor = IOProcessor(in_processor, out_processor)
    args.func(processor, **options)
84+
85+
if __name__ == '__main__':
86+
main()
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# drums detected by the DrumTranscriptor program with default settings
2+
0.130 0
3+
0.130 2
4+
0.480 2
5+
0.650 0
6+
0.800 0
7+
0.840 1
8+
0.840 2
9+
1.010 1
10+
1.160 0
11+
1.160 2
12+
1.500 2
13+
1.510 0
14+
1.660 1
15+
1.660 2
16+
1.840 0
17+
1.840 2
18+
2.180 1
19+
2.180 2
20+
2.360 1
21+
2.520 2
22+
2.700 0

tests/test_bin.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,47 @@ def test_run(self):
742742
self.assertTrue(np.allclose(result, self.result, atol=1e-5))
743743

744744

745+
class TestDrumTranscriptorProgram(unittest.TestCase):
    """Run the DrumTranscriptor program and compare against reference data."""

    def setUp(self):
        # program under test plus the reference activations and detections
        self.bin = pj(program_path, "DrumTranscriptor")
        activations_file = pj(ACTIVATIONS_PATH, "sample.drums_crnn.npz")
        self.activations = Activations(activations_file)
        detections_file = pj(DETECTIONS_PATH, "sample.drum_transcriptor.txt")
        self.result = np.loadtxt(detections_file)

    def test_help(self):
        # the help must be obtainable
        help_ok = run_help(self.bin)
        self.assertTrue(help_ok)

    def test_binary(self):
        # write the activations in binary format
        save_cmd = [self.bin, '--save', 'single', sample_file, '-o', tmp_act]
        run_program(save_cmd)
        saved = Activations(tmp_act)
        self.assertTrue(np.allclose(saved, self.activations, atol=1e-5))
        self.assertEqual(saved.fps, self.activations.fps)
        # read them back and run the remaining processing
        load_cmd = [self.bin, '--load', 'single', tmp_act, '-o', tmp_result]
        run_program(load_cmd)
        detections = np.loadtxt(tmp_result)
        self.assertTrue(np.allclose(detections, self.result, atol=1e-5))

    def test_txt(self):
        # write the activations in text format
        save_cmd = [self.bin, '--save', '--sep', ' ', 'single', sample_file,
                    '-o', tmp_act]
        run_program(save_cmd)
        saved = Activations(tmp_act, sep=' ', fps=100)
        self.assertTrue(np.allclose(saved, self.activations, atol=1e-5))
        # read them back and run the remaining processing
        load_cmd = [self.bin, '--load', '--sep', ' ', 'single', tmp_act,
                    '-o', tmp_result]
        run_program(load_cmd)
        detections = np.loadtxt(tmp_result)
        self.assertTrue(np.allclose(detections, self.result, atol=1e-5))

    def test_run(self):
        # process the sample file directly
        run_cmd = [self.bin, 'single', sample_file, '-o', tmp_result]
        run_program(run_cmd)
        detections = np.loadtxt(tmp_result)
        self.assertTrue(np.allclose(detections, self.result, atol=1e-5))
785+
745786
class TestSpectralOnsetDetectionProgram(unittest.TestCase):
746787
def setUp(self):
747788
self.bin = pj(program_path, "SpectralOnsetDetection")

tests/test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
pj(DETECTIONS_PATH, 'sample.cnn_onset_detector.txt'),
6666
pj(DETECTIONS_PATH, 'sample.complex_flux.txt'),
6767
pj(DETECTIONS_PATH, 'sample.crf_beat_detector.txt'),
68+
pj(DETECTIONS_PATH, 'sample.drum_transcriptor.txt'),
6869
pj(DETECTIONS_PATH, 'sample.dbn_beat_tracker.txt'),
6970
pj(DETECTIONS_PATH, 'sample.dbn_downbeat_tracker.txt'),
7071
pj(DETECTIONS_PATH, 'sample.dc_chord_recognition.txt'),

0 commit comments

Comments
 (0)