77
88
99def decode_gaussian_blurred_probs (probs , vmin , vmax , deviation , threshold ):
10- num_bins = probs .shape [- 1 ]
10+ num_bins = int ( probs .shape [- 1 ])
1111 interval = (vmax - vmin ) / (num_bins - 1 )
1212 width = int (3 * deviation / interval ) # 3 * sigma
1313 idx = torch .arange (num_bins , device = probs .device )[None , None , :] # [1, 1, N]
@@ -24,14 +24,17 @@ def decode_gaussian_blurred_probs(probs, vmin, vmax, deviation, threshold):
2424 return values , rest
2525
2626
27- def decode_bounds_to_alignment (bounds ):
27+ def decode_bounds_to_alignment (bounds , use_diff = True ):
2828 bounds_step = bounds .cumsum (dim = 1 ).round ().long ()
29- bounds_inc = torch .diff (
30- bounds_step , dim = 1 , prepend = torch .full (
31- (bounds .shape [0 ], 1 ), fill_value = - 1 ,
32- dtype = bounds_step .dtype , device = bounds_step .device
33- )
34- ) > 0
29+ if use_diff :
30+ bounds_inc = torch .diff (
31+ bounds_step , dim = 1 , prepend = torch .full (
32+ (bounds .shape [0 ], 1 ), fill_value = - 1 ,
33+ dtype = bounds_step .dtype , device = bounds_step .device
34+ )
35+ ) > 0
36+ else :
37+ bounds_inc = F .pad ((bounds_step [:, 1 :] > bounds_step [:, :- 1 ]), [1 , 0 ], value = True )
3538 frame2item = bounds_inc .long ().cumsum (dim = 1 )
3639 return frame2item
3740
@@ -48,25 +51,25 @@ def decode_note_sequence(frame2item, values, masks, threshold=0.5):
4851 b = frame2item .shape [0 ]
4952 space = frame2item .max () + 1
5053
51- item_dur = frame2item .new_zeros (b , space ).scatter_add (
54+ item_dur = frame2item .new_zeros (b , space , dtype = frame2item . dtype ).scatter_add (
5255 1 , frame2item , torch .ones_like (frame2item )
5356 )[:, 1 :]
54- item_unmasked_dur = frame2item .new_zeros (b , space ).scatter_add (
57+ item_unmasked_dur = frame2item .new_zeros (b , space , dtype = frame2item . dtype ).scatter_add (
5558 1 , frame2item , masks .long ()
5659 )[:, 1 :]
5760 item_masks = item_unmasked_dur / item_dur >= threshold
5861
5962 values_quant = values .round ().long ()
60- histogram = frame2item .new_zeros (b , space * 128 ).scatter_add (
63+ histogram = frame2item .new_zeros (b , space * 128 , dtype = frame2item . dtype ).scatter_add (
6164 1 , frame2item * 128 + values_quant , torch .ones_like (frame2item ) * masks
6265 ).unflatten (1 , [space , 128 ])[:, 1 :, :]
63- item_values_center = histogram .argmax (dim = 2 ).to (dtype = values .dtype )
66+ item_values_center = histogram .float (). argmax (dim = 2 ).to (dtype = values .dtype )
6467 values_center = torch .gather (F .pad (item_values_center , [1 , 0 ]), 1 , frame2item )
6568 values_near_center = masks & (values >= values_center - 0.5 ) & (values <= values_center + 0.5 )
66- item_valid_dur = frame2item .new_zeros (b , space ).scatter_add (
69+ item_valid_dur = frame2item .new_zeros (b , space , dtype = frame2item . dtype ).scatter_add (
6770 1 , frame2item , values_near_center .long ()
6871 )[:, 1 :]
69- item_values = values .new_zeros (b , space ).scatter_add (
72+ item_values = values .new_zeros (b , space , dtype = values . dtype ).scatter_add (
7073 1 , frame2item , values * values_near_center
7174 )[:, 1 :] / (item_valid_dur + (item_valid_dur == 0 ))
7275
0 commit comments