Skip to content

Commit a4e8a39

Browse files
Basic online BeatTracker
1 parent a534cb7 commit a4e8a39

File tree

4 files changed

+140
-11
lines changed

4 files changed

+140
-11
lines changed

bin/BeatTracker

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def main():
5959
# version
6060
p.add_argument('--version', action='version', version='BeatTracker.2016')
6161
# input/output arguments
62-
io_arguments(p, output_suffix='.beats.txt')
62+
io_arguments(p, output_suffix='.beats.txt', online=True)
6363
ActivationsProcessor.add_arguments(p)
6464
# signal processing arguments
6565
SignalProcessor.add_arguments(p, norm=False, gain=0)

madmom/features/beats.py

Lines changed: 111 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..audio.signal import smooth as smooth_signal
1717
from ..ml.nn import average_predictions
1818
from ..processors import (OnlineProcessor, ParallelProcessor, Processor,
19-
SequentialProcessor)
19+
SequentialProcessor, BufferProcessor)
2020

2121

2222
# classes for tracking (down-)beats with RNNs
@@ -324,7 +324,7 @@ def recursive(position):
324324

325325

326326
# classes for detecting/tracking of beat inside a beat activation function
327-
class BeatTrackingProcessor(Processor):
327+
class BeatTrackingProcessor(OnlineProcessor):
328328
"""
329329
Track the beats according to previously determined (local) tempo by
330330
iteratively aligning them around the estimated position [1]_.
@@ -337,7 +337,11 @@ class BeatTrackingProcessor(Processor):
337337
next beat.
338338
look_ahead : float, optional
339339
Look `look_ahead` seconds in both directions to determine the local
340-
tempo and align the beats accordingly.
340+
tempo and align the beats accordingly. For online only look into the
341+
past.
342+
threshold : float, optional
343+
Only accept activations as beat which exceed that threshold.
344+
Currently only available in online mode.
341345
tempo_estimator : :class:`TempoEstimationProcessor`, optional
342346
Use this processor to estimate the (local) tempo. If 'None' a default
343347
tempo estimator will be created and used.
@@ -391,22 +395,42 @@ class BeatTrackingProcessor(Processor):
391395
"""
392396
LOOK_ASIDE = 0.2
393397
LOOK_AHEAD = 10.
398+
THRESHOLD = 0.1
394399

395-
def __init__(self, look_aside=LOOK_ASIDE, look_ahead=LOOK_AHEAD, fps=None,
396-
tempo_estimator=None, **kwargs):
400+
def __init__(self, look_aside=LOOK_ASIDE, look_ahead=LOOK_AHEAD,
401+
threshold=THRESHOLD, tempo_estimator=None, fps=None,
402+
online=False, **kwargs):
403+
# pylint: disable=unused-argument
404+
super(BeatTrackingProcessor, self).__init__(online=online)
397405
# save variables
398406
self.look_aside = look_aside
399407
self.look_ahead = look_ahead
408+
self.threshold = threshold
400409
self.fps = fps
401410
# tempo estimator
402411
if tempo_estimator is None:
403412
# import the TempoEstimation here otherwise we have a loop
404413
from .tempo import TempoEstimationProcessor
405414
# create default tempo estimator
406-
tempo_estimator = TempoEstimationProcessor(fps=fps, **kwargs)
415+
tempo_estimator = TempoEstimationProcessor(fps=fps, online=online,
416+
**kwargs)
407417
self.tempo_estimator = tempo_estimator
418+
if self.online:
419+
self.visualize = kwargs.get('verbose', False)
420+
self.buffer = BufferProcessor(int(look_ahead * self.fps))
421+
self.counter = 0
422+
self.beat_counter = 0
423+
self.last_beat = 0
408424

409-
def process(self, activations, **kwargs):
425+
def reset(self):
426+
"""Reset the BeatTrackingProcessor."""
427+
self.tempo_estimator.reset()
428+
self.buffer.reset()
429+
self.counter = 0
430+
self.beat_counter = 0
431+
self.last_beat = 0
432+
433+
def process_offline(self, activations, **kwargs):
410434
"""
411435
Detect the beats in the given activation function.
412436
@@ -476,9 +500,76 @@ def process(self, activations, **kwargs):
476500
# remove beats with negative times and return them
477501
return detections[np.searchsorted(detections, 0):]
478502

503+
def process_online(self, activations, reset=True, **kwargs):
504+
"""
505+
Detect the beats in the given activation function for online mode.
506+
507+
Parameters
508+
----------
509+
activations : numpy array
510+
Beat activation function.
511+
reset : bool, optional
512+
Reset the BeatTrackingProcessor to its initial state before
513+
processing.
514+
515+
Returns
516+
-------
517+
beats : numpy array
518+
Detected beat positions [seconds].
519+
520+
"""
521+
# reset to initial state
522+
if reset:
523+
self.reset()
524+
beats_ = []
525+
for activation in activations:
526+
# shift buffer and put new activation at end of buffer
527+
buffer = self.buffer(activation)
528+
# update online tempo hypothesis with newest activation
529+
histogram = self.tempo_estimator.interval_histogram(
530+
np.array([activation]), reset=reset)
531+
# get the dominant interval
532+
interval = self.tempo_estimator.dominant_interval(histogram)
533+
# compute the current and the next possible beat time
534+
cur_beat = self.counter / float(self.fps)
535+
next_beat = self.last_beat + 60. / self.tempo_estimator.max_bpm
536+
# only detect beats again after at least min_interval frames
537+
detections = []
538+
if cur_beat >= next_beat:
539+
detections = detect_beats(buffer, interval, self.look_aside)
540+
# if a detection falls within the last few frames it may be a beat
541+
# this is done because for every frame the tempo or the detections
542+
# may change and therefore the last beat can easily be missed.
543+
look_back = len(buffer) - (interval * self.look_aside)
544+
# a beat also has to exceed a certain threshold
545+
if len(detections) and detections[-1] >= look_back and \
546+
buffer[detections[-1]] > self.threshold:
547+
# append to beats
548+
beats_.append(cur_beat)
549+
# update last beat
550+
self.last_beat = cur_beat
551+
# visualize beats
552+
if self.visualize:
553+
display = ['']
554+
if len(beats_) > 0 and beats_[-1] == cur_beat:
555+
self.beat_counter = 10
556+
if self.beat_counter > 0:
557+
display.append('| X ')
558+
else:
559+
display.append('| ')
560+
self.beat_counter -= 1
561+
# display tempo
562+
display.append('| %5.1f | ' % float(self.fps * 60 / interval))
563+
sys.stderr.write('\r%s' % ''.join(display))
564+
sys.stderr.flush()
565+
# increase counter
566+
self.counter += 1
567+
# return beat(s)
568+
return np.array(beats_)
569+
479570
@staticmethod
480571
def add_arguments(parser, look_aside=LOOK_ASIDE,
481-
look_ahead=LOOK_AHEAD):
572+
look_ahead=LOOK_AHEAD, threshold=THRESHOLD):
482573
"""
483574
Add beat tracking related arguments to an existing parser.
484575
@@ -492,7 +583,11 @@ def add_arguments(parser, look_aside=LOOK_ASIDE,
492583
of the next beat.
493584
look_ahead : float, optional
494585
Look `look_ahead` seconds in both directions to determine the local
495-
tempo and align the beats accordingly.
586+
tempo and align the beats accordingly. For online only look into
587+
the past.
588+
threshold : float, optional
589+
Only accept activations as beat which exceed that threshold.
590+
Currently only available in online mode.
496591
497592
Returns
498593
-------
@@ -520,6 +615,12 @@ def add_arguments(parser, look_aside=LOOK_ASIDE,
520615
help='look this many seconds in both directions '
521616
'to determine the local tempo and align the '
522617
'beats accordingly [default=%(default).2f]')
618+
if threshold is not None:
619+
g.add_argument('--threshold', action='store', type=float,
620+
default=threshold,
621+
help='only accept activations as beat which exceed '
622+
'that threshold (currently only for online) '
623+
'[default=%(default).2f]')
523624
# return the argument group so it can be modified if needed
524625
return g
525626

@@ -891,7 +992,7 @@ def __init__(self, min_bpm=MIN_BPM, max_bpm=MAX_BPM, num_tempi=NUM_TEMPI,
891992
self.fps = fps
892993
self.min_bpm = min_bpm
893994
self.max_bpm = max_bpm
894-
# kepp state in online mode
995+
# keep state in online mode
895996
self.online = online
896997
# TODO: refactor the visualisation stuff
897998
if self.online:

tests/test_bin.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def setUp(self):
211211
pj(ACTIVATIONS_PATH, "sample.beats_blstm.npz"))
212212
self.result = np.loadtxt(
213213
pj(DETECTIONS_PATH, "sample.beat_tracker.txt"))
214+
self.online_results = [0.78, 1.14, 1.48, 1.84, 2.18, 2.51]
214215

215216
def test_help(self):
216217
self.assertTrue(run_help(self.bin))
@@ -241,6 +242,14 @@ def test_run(self):
241242
result = np.loadtxt(tmp_result)
242243
self.assertTrue(np.allclose(result, self.result, atol=1e-5))
243244

245+
def test_online(self):
246+
run_online(self.bin, sample_file, tmp_result)
247+
result = np.loadtxt(tmp_result)
248+
self.assertTrue(np.allclose(result, self.online_results))
249+
run_single(self.bin, sample_file, tmp_result, online=True)
250+
result = np.loadtxt(tmp_result)
251+
self.assertTrue(np.allclose(result, self.online_results))
252+
244253

245254
class TestCNNChordRecognition(unittest.TestCase):
246255
def setUp(self):

tests/test_features_beats.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,25 @@ def test_process(self):
6666
self.assertTrue(np.allclose(beats, [0.11, 0.45, 0.79, 1.13, 1.47,
6767
1.81, 2.15, 2.49]))
6868

69+
def test_process_online(self):
70+
processor = BeatTrackingProcessor(fps=sample_lstm_act.fps,
71+
online=True)
72+
# compute the beats at once
73+
beats = processor.process_online(sample_lstm_act, reset=False)
74+
self.assertTrue(np.allclose(beats, [0.68, 1.14, 1.48, 1.84, 2.18,
75+
2.51]))
76+
# compute the beats framewise
77+
processor.reset()
78+
beats = [processor.process_online(np.atleast_2d(act), reset=False)
79+
for act in sample_lstm_act]
80+
self.assertTrue(np.allclose(np.nonzero(beats),
81+
[68, 114, 148, 184, 218, 251]))
82+
# without resetting results are different
83+
beats = [processor.process_online(np.atleast_2d(act), reset=False)
84+
for act in sample_lstm_act]
85+
self.assertTrue(np.allclose(np.nonzero(beats), [5, 148, 184, 217,
86+
251]))
87+
6988

7089
class TestBeatDetectionProcessorClass(unittest.TestCase):
7190

0 commit comments

Comments
 (0)