1+ """A degrader object can be used to easily degrade data points on the fly
2+ according to some given parameters."""
3+ import json
4+ import numpy as np
5+ import warnings
6+
7+ import mdtk .degradations as degs
8+
class Degrader:
    """A Degrader object can be used to easily degrade musical excerpts
    on the fly."""

    def __init__(self, seed=None, degradations=list(degs.DEGRADATIONS.keys()),
                 degradation_dist=np.ones(len(degs.DEGRADATIONS)),
                 clean_prop=1 / (len(degs.DEGRADATIONS) + 1), config=None):
        """
        Create a new degrader with the given parameters.

        Parameters
        ----------
        seed : int
            A random seed for numpy. If not None, it is passed to
            np.random.seed, which seeds numpy's global random state.

        degradations : list(string)
            A list of the names of the degradations to use (and in what order
            to label them). Each name is looked up in degs.DEGRADATIONS
            when degrade is called.

        degradation_dist : list(float)
            A list of the relative weight of each degradation given in
            degradations. It is stored as given; degrade normalizes it
            to sum to 1 whenever it samples from it.

        clean_prop : float
            The proportion of degrade calls that should return clean excerpts.

        config : string
            The path of a json config file (created by measure_errors.py).
            If given, any 'degradation_dist' and 'clean_prop' keys present
            in the file override the corresponding arguments. When
            'degradation_dist' is present, degradations is also reset to
            all of the keys of degs.DEGRADATIONS.

        Raises
        ------
        AssertionError
            If degradation_dist and degradations differ in length, if any
            degradation_dist value is negative, if no degradation_dist value
            is positive, or if clean_prop is outside of [0, 1].
        """
        if seed is not None:
            np.random.seed(seed)

        # Load config
        if config is not None:
            with open(config, 'r') as file:
                config_values = json.load(file)

            if 'degradation_dist' in config_values:
                degradation_dist = np.array(config_values['degradation_dist'])
                # A config distribution is taken to cover every known
                # degradation, so the degradations argument is replaced.
                degradations = list(degs.DEGRADATIONS.keys())
            if 'clean_prop' in config_values:
                clean_prop = config_values['clean_prop']

        # Check arg validity
        assert len(degradation_dist) == len(degradations), (
            "Given degradation_dist is not the same length as degradations:"
            f"\nlen({degradation_dist}) != len({degradations})"
        )
        assert min(degradation_dist) >= 0, ("degradation_dist values must "
                                            "not be negative.")
        assert sum(degradation_dist) > 0, ("Some degradation_dist value "
                                           "must be positive.")
        assert 0 <= clean_prop <= 1, ("clean_prop must be between 0 and 1 "
                                      "(inclusive).")

        self.degradations = degradations
        self.degradation_dist = degradation_dist
        self.clean_prop = clean_prop
        # Per-degradation count of failed attempts that have not yet been
        # made up for by a later successful attempt (see degrade).
        self.failed = np.zeros(len(degradations))

    def _try_degradation(self, deg_index, note_df):
        """Run degradation number deg_index on note_df with warnings
        suppressed, returning whatever the degradation function returns
        (None is treated as failure by degrade)."""
        deg_fun = degs.DEGRADATIONS[self.degradations[deg_index]]
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return deg_fun(note_df)

    def degrade(self, note_df):
        """
        Degrade the given note_df.

        Degradations which failed (returned None) on earlier calls are
        tried first, chosen proportionally to their outstanding failure
        counts in self.failed. If none of those succeeds, degradations are
        drawn from self.degradation_dist. Within a single call, each
        degradation is attempted (or counted as failed) at most once per
        phase, so the call always terminates.

        Parameters
        ----------
        note_df : pd.DataFrame
            A note_df to degrade.

        Returns
        -------
        degraded_df : pd.DataFrame
            A degraded version of the given note_df. If self.clean_prop > 0,
            or if every candidate degradation fails, this can be a copy of
            the given note_df.

        deg_label : int
            The label of the degradation that was performed. 0 means none,
            and larger numbers mean the degradation
            "self.degradations[deg_label-1]" was performed.
        """
        if self.clean_prop > 0 and np.random.rand() <= self.clean_prop:
            return note_df.copy(), 0

        num_degs = len(self.degradations)
        # Per-call working copies. np.array (rather than .copy()) lets a
        # plain list degradation_dist support the comparisons and division
        # used below.
        this_deg_dist = np.array(self.degradation_dist, dtype=float)
        this_failed = self.failed.copy()

        # First, sample from failed degradations
        while np.any(this_failed > 0):
            # Select a degradation proportional to how many have failed
            deg_index = np.random.choice(
                num_degs,
                p=this_failed / np.sum(this_failed)
            )
            degraded_df = self._try_degradation(deg_index, note_df)

            # Check for success!
            if degraded_df is not None:
                self.failed[deg_index] -= 1
                return degraded_df, deg_index + 1

            # Degradation failed -- 0 out this deg and continue
            this_failed[deg_index] = 0

        # No degradations have remaining failures. Draw from standard dist
        while np.any(this_deg_dist > 0):
            # Select a degradation proportional to the distribution
            deg_index = np.random.choice(
                num_degs,
                p=this_deg_dist / np.sum(this_deg_dist)
            )
            # Whatever happens below, this degradation must not be drawn
            # again during this call; otherwise the loop could never end
            # when every degradation fails.
            this_deg_dist[deg_index] = 0

            # This deg would have already failed in the above loop.
            # But we want to sample it and count it as another failure.
            if self.failed[deg_index] > 0:
                self.failed[deg_index] += 1
                continue

            degraded_df = self._try_degradation(deg_index, note_df)

            # Check for success!
            if degraded_df is not None:
                return degraded_df, deg_index + 1

            # Degradation failed -- add 1 to failure and continue
            self.failed[deg_index] += 1

        # Here, all degradations (with dist > 0) failed
        return note_df.copy(), 0
149+
0 commit comments