1616from ..audio .signal import smooth as smooth_signal
1717from ..ml .nn import average_predictions
1818from ..processors import (OnlineProcessor , ParallelProcessor , Processor ,
19- SequentialProcessor )
19+ SequentialProcessor , BufferProcessor )
2020
2121
2222# classes for tracking (down-)beats with RNNs
@@ -324,7 +324,7 @@ def recursive(position):
324324
325325
326326# classes for detecting/tracking of beat inside a beat activation function
327- class BeatTrackingProcessor (Processor ):
327+ class BeatTrackingProcessor (OnlineProcessor ):
328328 """
329329 Track the beats according to previously determined (local) tempo by
330330 iteratively aligning them around the estimated position [1]_.
@@ -337,7 +337,11 @@ class BeatTrackingProcessor(Processor):
337337 next beat.
338338 look_ahead : float, optional
339339 Look `look_ahead` seconds in both directions to determine the local
340- tempo and align the beats accordingly.
340+ tempo and align the beats accordingly. For online only look into the
341+ past.
342+ threshold : float, optional
343+ Only accept activations as beat which exceed that threshold.
344+ Currently only available in online mode.
341345 tempo_estimator : :class:`TempoEstimationProcessor`, optional
342346 Use this processor to estimate the (local) tempo. If 'None' a default
343347 tempo estimator will be created and used.
@@ -391,22 +395,42 @@ class BeatTrackingProcessor(Processor):
391395 """
392396 LOOK_ASIDE = 0.2
393397 LOOK_AHEAD = 10.
398+ THRESHOLD = 0.1
394399
def __init__(self, look_aside=LOOK_ASIDE, look_ahead=LOOK_AHEAD,
             threshold=THRESHOLD, tempo_estimator=None, fps=None,
             online=False, **kwargs):
    """
    Set up the beat tracker and, if needed, a default tempo estimator.

    Parameters
    ----------
    look_aside : float, optional
        Stored as `self.look_aside`; passed on to `detect_beats` in online
        mode.
    look_ahead : float, optional
        Stored as `self.look_ahead`; in online mode it also determines the
        length of the activation buffer (`look_ahead * fps` frames).
    threshold : float, optional
        Stored as `self.threshold`; online detections must exceed it.
    tempo_estimator : :class:`TempoEstimationProcessor`, optional
        Tempo estimator to use. If 'None', a default one is created with
        `fps`, `online` and the remaining keyword arguments.
    fps : float, optional
        Frames per second. Required if `online` is True.
    online : bool, optional
        Operate in online mode.
    kwargs : dict, optional
        Passed to the default tempo estimator; the 'verbose' entry also
        switches on the online visualisation.

    Raises
    ------
    ValueError
        If `online` is True but `fps` is not given.

    """
    # pylint: disable=unused-argument
    super(BeatTrackingProcessor, self).__init__(online=online)
    # save variables
    self.look_aside = look_aside
    self.look_ahead = look_ahead
    self.threshold = threshold
    self.fps = fps
    # tempo estimator
    if tempo_estimator is None:
        # import the TempoEstimation here otherwise we have a loop
        from .tempo import TempoEstimationProcessor
        # create default tempo estimator
        tempo_estimator = TempoEstimationProcessor(fps=fps, online=online,
                                                   **kwargs)
    self.tempo_estimator = tempo_estimator
    # always define the online state, so that the attributes exist (and
    # `reset()` is safe to call) regardless of the mode
    self.visualize = False
    self.buffer = None
    self.counter = 0
    self.beat_counter = 0
    self.last_beat = 0
    if self.online:
        # the buffer length is given in frames, thus `fps` is mandatory
        if self.fps is None:
            raise ValueError('`fps` must be given in online mode')
        self.visualize = kwargs.get('verbose', False)
        self.buffer = BufferProcessor(int(look_ahead * self.fps))
408424
409- def process (self , activations , ** kwargs ):
def reset(self):
    """
    Reset the BeatTrackingProcessor.

    Resets the tempo estimator, the activation buffer (if there is one) and
    sets the frame counter, the visualisation counter and the time of the
    last beat back to 0.

    """
    self.tempo_estimator.reset()
    # the buffer is created in online mode only, thus do not fail if the
    # processor was instantiated for offline processing
    buffer = getattr(self, 'buffer', None)
    if buffer is not None:
        buffer.reset()
    self.counter = 0
    self.beat_counter = 0
    self.last_beat = 0
432+
433+ def process_offline (self , activations , ** kwargs ):
410434 """
411435 Detect the beats in the given activation function.
412436
@@ -476,9 +500,76 @@ def process(self, activations, **kwargs):
476500 # remove beats with negative times and return them
477501 return detections [np .searchsorted (detections , 0 ):]
478502
def process_online(self, activations, reset=True, **kwargs):
    """
    Detect the beats in the given activation function for online mode.

    The activations are processed one by one: each value is pushed into
    the buffer and handed to the tempo estimator, then the buffer is
    searched for beats with the current dominant interval. A beat is
    reported at the current time if the last detection lies close to the
    end of the buffer and its activation exceeds `self.threshold`.

    Parameters
    ----------
    activations : numpy array
        Beat activation function.
    reset : bool, optional
        Reset the BeatTrackingProcessor to its initial state before
        processing.

    Returns
    -------
    beats : numpy array
        Detected beat positions [seconds].

    """
    # reset to initial state (this also resets the tempo estimator)
    if reset:
        self.reset()
    beats_ = []
    for activation in activations:
        # shift buffer and put new activation at end of buffer
        buffer = self.buffer(activation)
        # update online tempo hypothesis with newest activation
        # Note: never reset the tempo estimator here, otherwise its state
        #       would be discarded for every single frame; the reset
        #       requested by the caller was already performed above
        histogram = self.tempo_estimator.interval_histogram(
            np.array([activation]), reset=False)
        # get the dominant interval
        interval = self.tempo_estimator.dominant_interval(histogram)
        # compute the current and the next possible beat time
        cur_beat = self.counter / float(self.fps)
        next_beat = self.last_beat + 60. / self.tempo_estimator.max_bpm
        # only detect beats again after the shortest possible beat period
        # (given by the maximum tempo) has passed since the last beat
        detections = []
        if cur_beat >= next_beat:
            detections = detect_beats(buffer, interval, self.look_aside)
        # if a detection falls within the last few frames it may be a beat
        # this is done because for every frame the tempo or the detections
        # may change and therefore the last beat can easily be missed.
        look_back = len(buffer) - (interval * self.look_aside)
        # a beat also has to exceed a certain threshold
        if len(detections) and detections[-1] >= look_back and \
                buffer[detections[-1]] > self.threshold:
            # append to beats
            beats_.append(cur_beat)
            # update last beat
            self.last_beat = cur_beat
        # visualize beats
        if self.visualize:
            display = ['']
            # keep the beat marker visible for the next 10 frames
            if len(beats_) > 0 and beats_[-1] == cur_beat:
                self.beat_counter = 10
            if self.beat_counter > 0:
                display.append('| X ')
            else:
                display.append('| ')
            self.beat_counter -= 1
            # display tempo
            display.append('| %5.1f | ' % float(self.fps * 60 / interval))
            sys.stderr.write('\r %s' % ''.join(display))
            sys.stderr.flush()
        # increase counter
        self.counter += 1
    # return beat(s)
    return np.array(beats_)
569+
479570 @staticmethod
480571 def add_arguments (parser , look_aside = LOOK_ASIDE ,
481- look_ahead = LOOK_AHEAD ):
572+ look_ahead = LOOK_AHEAD , threshold = THRESHOLD ):
482573 """
483574 Add beat tracking related arguments to an existing parser.
484575
@@ -492,7 +583,11 @@ def add_arguments(parser, look_aside=LOOK_ASIDE,
492583 of the next beat.
493584 look_ahead : float, optional
494585 Look `look_ahead` seconds in both directions to determine the local
495- tempo and align the beats accordingly.
586+ tempo and align the beats accordingly. For online only look into
587+ the past.
588+ threshold : float, optional
589+ Only accept activations as beat which exceed that threshold.
590+ Currently only available in online mode.
496591
497592 Returns
498593 -------
@@ -520,6 +615,12 @@ def add_arguments(parser, look_aside=LOOK_ASIDE,
520615 help = 'look this many seconds in both directions '
521616 'to determine the local tempo and align the '
522617 'beats accordingly [default=%(default).2f]' )
618+ if threshold is not None :
619+ g .add_argument ('--threshold' , action = 'store' , type = float ,
620+ default = threshold ,
621+ help = 'only accept activations as beat which exceed '
622+ 'that threshold (currently only for online) '
623+ '[default=%(default).2f]' )
523624 # return the argument group so it can be modified if needed
524625 return g
525626
@@ -891,7 +992,7 @@ def __init__(self, min_bpm=MIN_BPM, max_bpm=MAX_BPM, num_tempi=NUM_TEMPI,
891992 self .fps = fps
892993 self .min_bpm = min_bpm
893994 self .max_bpm = max_bpm
894- # kepp state in online mode
995+ # keep state in online mode
895996 self .online = online
896997 # TODO: refactor the visualisation stuff
897998 if self .online :
0 commit comments