Skip to content

Commit 17393c4

Browse files
authored
Merge pull request #119 from JamesOwers/amt
Amt
2 parents 993e82a + 2054062 commit 17393c4

File tree

10 files changed

+1146
-20
lines changed

10 files changed

+1146
-20
lines changed

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Tools to generate datasets of Altered and Corrupted MIDI Excerpts -`ACME`
33
datasets.
44

5-
The accompanying paper (submitted to ICASSP, available upon request)
5+
The accompanying paper (submitted to ISMIR)
66
"Symbolic Music Correction using The MIDI Degradation Toolkit" describes the
77
toolkit and its motivation in detail. For instructions to reproduce the results
88
from the paper, see [`./baselines/README.md`](./baselines/README.md).
@@ -38,6 +38,8 @@ Some highlights include:
3838
data for use
3939
* [`mdtk.degradations`](./mdtk/degradations.py) - functions to alter midi data
4040
e.g. `pitch_shift` or `time_shift`
41+
* [`mdtk.degrader`](./mdtk/degrader.py) - Degrader class that can be used to
42+
degrade data points randomly on the fly
4143
* [`mdtk.eval`](./mdtk/eval.py) - functions for evaluating model performance
4244
on each task, given a list of outputs and targets
4345
* [`mdtk.formatters`](./mdtk/formatters.py) - functions converting between
@@ -71,3 +73,10 @@ pip install . # use pip install -e . for dev mode if you want to edit files
7173

7274
To generate an `ACME` dataset simply install the package with instructions
7375
above and run `./make_dataset.py`.
76+
77+
For usage instructions for the `measure_errors.py` script, run
78+
`python measure_errors.py -h`. You should create a directory of transcriptions
79+
and a directory of ground truth files (in mid or csv format). The ground truth
80+
and its corresponding transcription must have exactly the same file name.
81+
82+
See `measure_errors_example.ipynb` for an example of the script's usage.

make_dataset.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ def parse_args(args_input=None):
8787
parser.add_argument('-i', '--input-dir', type=str, default=default_indir,
8888
help='the directory to store the preprocessed '
8989
'downloaded data to.')
90+
parser.add_argument('--config', default=None, help='Load a json config '
91+
'file, in the format created by measure_errors.py. '
92+
'This will override --degradations, --degradation-'
93+
'dist, and --clean-prop.')
9094
parser.add_argument('--formats', metavar='format', help='Create '
9195
'custom versions of the acme data for easier loading '
9296
'with our provided pytorch Dataset classes. Choices are'
@@ -175,6 +179,7 @@ def parse_args(args_input=None):
175179
seed = ARGS.seed
176180
print(f'Setting random seed to {seed}.')
177181
np.random.seed(seed)
182+
178183
# Check given degradation_kwargs
179184
assert (ARGS.degradation_kwargs is None or
180185
ARGS.degradation_kwarg_json is None), ("Don't specify both "
@@ -185,6 +190,18 @@ def parse_args(args_input=None):
185190
degradation_kwargs = parse_degradation_kwargs(
186191
ARGS.degradation_kwarg_json
187192
)
193+
194+
# Load config
195+
if ARGS.config is not None:
196+
with open(ARGS.config, 'r') as file:
197+
config = json.load(file)
198+
if ARGS.verbose:
199+
print(f'Loading from config file {ARGS.config}.')
200+
if 'degradation_dist' in config:
201+
ARGS.degradation_dist = np.array(config['degradation_dist'])
202+
ARGS.degradations = list(degradations.DEGRADATIONS.keys())
203+
if 'clean_prop' in config:
204+
ARGS.clean_prop = config['clean_prop']
188205
# Warn user they specified kwargs for degradation not being used
189206
for deg, args in degradation_kwargs.items():
190207
if deg not in ARGS.degradations and len(args) > 0:

mdtk/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Description
22

3-
The accompanying paper (submitted to ICASSP, available upon request) referenced
3+
The accompanying paper (submitted to ISMIR) referenced
44
below is "Symbolic Music Correction using The MIDI Degradation Toolkit" and
55
describes the toolkit and its motivation in detail.
66

mdtk/data_structures.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ def fix_overlaps(df):
298298
df.loc[prev_idx, 'offset'] = note.onset
299299
current_offset = max(current_offset, note.offset)
300300
df.loc[idx, 'offset'] = current_offset
301+
else:
302+
# No overlap. Update latest offset.
303+
current_offset = note.offset
301304
# Always iterate, but no need to update current_offset here,
302305
# because it will definitely be < next_note.onset (because sorted).
303306
prev_idx = idx

mdtk/degradations.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
MIN_DURATION_DEFAULT = 50
1818
MAX_DURATION_DEFAULT = np.inf
1919

20+
MAX_GAP_DEFAULT = 50
21+
2022
TRIES_DEFAULT = 10
2123

2224
TRIES_WARN_MSG = ("WARNING: Generated invalid (overlapping) degraded excerpt "
@@ -259,7 +261,7 @@ def pitch_shift(excerpt, min_pitch=MIN_PITCH_DEFAULT,
259261
'distribution[zero_idx] to 0). Returning None.')
260262
return None
261263

262-
degraded = excerpt
264+
degraded = excerpt.copy()
263265

264266
# Sample a random note
265267
note_index = valid_notes[randint(len(valid_notes))]
@@ -395,7 +397,7 @@ def time_shift(excerpt, min_shift=MIN_SHIFT_DEFAULT,
395397
else:
396398
onset = split_range_sample([(eeo, leo), (elo, llo)])
397399

398-
degraded = excerpt
400+
degraded = excerpt.copy()
399401

400402
degraded.loc[index, 'onset'] = onset
401403

@@ -575,7 +577,7 @@ def onset_shift(excerpt, min_shift=MIN_SHIFT_DEFAULT,
575577
# No alignment
576578
onset = split_range_sample([(elo, llo), (eso, lso)])
577579

578-
degraded = excerpt
580+
degraded = excerpt.copy()
579581

580582
degraded.loc[index, 'onset'] = onset
581583
degraded.loc[index, 'dur'] = offset[index] - onset
@@ -706,7 +708,7 @@ def offset_shift(excerpt, min_shift=MIN_SHIFT_DEFAULT,
706708
else:
707709
duration = split_range_sample([(ssd, lsd), (sld, lld)])
708710

709-
degraded = excerpt
711+
degraded = excerpt.copy()
710712

711713
degraded.loc[index, 'dur'] = duration
712714

@@ -889,7 +891,7 @@ def add_note(excerpt, min_pitch=MIN_PITCH_DEFAULT, max_pitch=MAX_PITCH_DEFAULT,
889891
'dur': duration,
890892
'track': track}
891893

892-
degraded = excerpt
894+
degraded = excerpt.copy()
893895
degraded = degraded.append(note, ignore_index=True)
894896

895897
# Check if overlaps
@@ -977,7 +979,7 @@ def split_note(excerpt, min_duration=MIN_DURATION_DEFAULT, num_splits=1,
977979
onsets[i] = int(round(this_onset))
978980
durs[i] = int(round(next_onset)) - int(round(this_onset))
979981

980-
degraded = excerpt
982+
degraded = excerpt.copy()
981983
degraded.loc[note_index]['dur'] = int(round(short_duration_float))
982984
new_df = pd.DataFrame({'onset': onsets,
983985
'track': tracks,
@@ -991,8 +993,8 @@ def split_note(excerpt, min_duration=MIN_DURATION_DEFAULT, num_splits=1,
991993

992994

993995
@set_random_seed
994-
def join_notes(excerpt, max_gap=50, max_notes=20, only_first=False,
995-
tries=TRIES_DEFAULT):
996+
def join_notes(excerpt, max_gap=MAX_GAP_DEFAULT, max_notes=20,
997+
only_first=False, tries=TRIES_DEFAULT):
996998
"""
997999
Combine two notes of the same pitch and track into one.
9981000
@@ -1083,7 +1085,7 @@ def join_notes(excerpt, max_gap=50, max_notes=20, only_first=False,
10831085
start = valid_starts[index]
10841086
nexts = valid_nexts[index]
10851087

1086-
degraded = excerpt
1088+
degraded = excerpt.copy()
10871089

10881090
# Extend first note
10891091
degraded.loc[start]['dur'] = (degraded.loc[nexts[-1]]['onset'] +

mdtk/degrader.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""A degrader object can be used to easily degrade data points on the fly
2+
according to some given parameters."""
3+
import json
4+
import numpy as np
5+
import warnings
6+
7+
import mdtk.degradations as degs
8+
9+
class Degrader():
10+
"""A Degrade object can be used to easily degrade musical excerpts
11+
on the fly."""
12+
13+
def __init__(self, seed=None, degradations=list(degs.DEGRADATIONS.keys()),
14+
degradation_dist=np.ones(len(degs.DEGRADATIONS)),
15+
clean_prop=1 / (len(degs.DEGRADATIONS) + 1), config=None):
16+
"""
17+
Create a new degrader with the given parameters.
18+
19+
Parameters
20+
----------
21+
seed : int
22+
A random seed for numpy.
23+
24+
degradations : list(string)
25+
A list of the names of the degradations to use (and in what order
26+
to label them).
27+
28+
degradation_dist : list(float)
29+
A list of the probability of each degradation given in
30+
degradations. This list will be normalized to sum to 1.
31+
32+
clean_prop : float
33+
The proportion of degrade calls that should return clean excerpts.
34+
35+
config : string
36+
The path of a json config file (created by measure_errors.py).
37+
If given, degradations, degradation_dist, and clean_prop will
38+
all be overwritten by the values in the json file.
39+
"""
40+
if seed is not None:
41+
np.random.seed(seed)
42+
43+
# Load config
44+
if config is not None:
45+
with open(config, 'r') as file:
46+
config = json.load(file)
47+
48+
if 'degradation_dist' in config:
49+
degradation_dist = np.array(config['degradation_dist'])
50+
degradations = list(degs.DEGRADATIONS.keys())
51+
if 'clean_prop' in config:
52+
clean_prop = config['clean_prop']
53+
54+
# Check arg validity
55+
assert len(degradation_dist) == len(degradations), (
56+
"Given degradation_dist is not the same length as degradations:"
57+
f"\nlen({degradation_dist}) != len({degradations})"
58+
)
59+
assert min(degradation_dist) >= 0, ("degradation_dist values must "
60+
"not be negative.")
61+
assert sum(degradation_dist) > 0, ("Some degradation_dist value "
62+
"must be positive.")
63+
assert 0 <= clean_prop <= 1, ("clean_prop must be between 0 and 1 "
64+
"(inclusive).")
65+
66+
self.degradations = degradations
67+
self.degradation_dist = degradation_dist
68+
self.clean_prop = clean_prop
69+
self.failed = np.zeros(len(degradations))
70+
71+
72+
def degrade(self, note_df):
73+
"""
74+
Degrade the given note_df.
75+
76+
Parameters
77+
----------
78+
note_df : pd.DataFrame
79+
A note_df to degrade.
80+
81+
Returns
82+
-------
83+
degraded_df : pd.DataFrame
84+
A degraded version of the given note_df. If self.clean_prop > 0,
85+
this can be a copy of the given note_df.
86+
87+
deg_label : int
88+
The label of the degradation that was performed. 0 means none,
89+
and larger numbers mean the degradation
90+
"self.degradations[deg_label+1]" was performed.
91+
"""
92+
if self.clean_prop > 0 and np.random.rand() <= self.clean_prop:
93+
return note_df.copy(), 0
94+
95+
degraded_df = None
96+
this_deg_dist = self.degradation_dist.copy()
97+
this_failed = self.failed.copy()
98+
99+
# First, sample from failed degradations
100+
while np.any(this_failed > 0):
101+
# Select a degradation proportional to how many have failed
102+
deg_index = np.random.choice(
103+
len(self.degradations),
104+
p=this_failed / np.sum(this_failed)
105+
)
106+
deg_fun = degs.DEGRADATIONS[self.degradations[deg_index]]
107+
108+
# Try to degrade
109+
with warnings.catch_warnings():
110+
warnings.simplefilter("ignore")
111+
degraded_df = deg_fun(note_df)
112+
113+
# Check for success!
114+
if degraded_df is not None:
115+
self.failed[deg_index] -= 1
116+
return degraded_df, deg_index + 1
117+
118+
# Degradation failed -- 0 out this deg and continue
119+
this_failed[deg_index] = 0
120+
121+
# No degradations have remaining failures. Draw from standard dist
122+
while np.any(this_deg_dist > 0):
123+
# Select a degradation proportional to the distribution
124+
deg_index = np.random.choice(
125+
len(self.degradations),
126+
p=this_deg_dist / np.sum(this_deg_dist)
127+
)
128+
# This deg would have already failed in the above loop.
129+
# But we want to sample it and count it as another failure.
130+
if self.failed[deg_index] > 0:
131+
self.failed[deg_index] += 1
132+
continue
133+
deg_fun = degs.DEGRADATIONS[self.degradations[deg_index]]
134+
135+
# Try to degrade
136+
with warnings.catch_warnings():
137+
warnings.simplefilter("ignore")
138+
degraded_df = deg_fun(note_df)
139+
140+
# Check for success!
141+
if degraded_df is not None:
142+
return degraded_df, deg_index + 1
143+
144+
# Degradation failed -- add 1 to failure and continue
145+
self.failed[deg_index] += 1
146+
147+
# Here, all degradations (with dist > 0) failed
148+
return note_df.copy(), 0
149+

mdtk/filesystem_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import os
33
import sys
44
import shutil
5-
import urllib
5+
import urllib.request
6+
import urllib.error
67
import warnings
78
import zipfile
89

mdtk/tests/test_data_structures.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,16 @@
8484
'ch' :[0, 1, 0, 0, 0, 0]
8585
})
8686
note_df_complex_overlap = pd.DataFrame({
87-
'onset': [50, 75, 150, 200, 200, 300, 300, 300],
88-
'track': [0, 0, 0, 0, 0, 0, 0, 1],
89-
'pitch': [10, 10, 20, 10, 20, 30, 30, 10],
90-
'dur': [300, 25, 100, 125, 50, 50, 100, 100]
87+
'onset': [0, 50, 75, 150, 200, 200, 300, 300, 300],
88+
'track': [0, 0, 0, 0, 0, 0, 0, 0, 1],
89+
'pitch': [10, 10, 10, 20, 10, 20, 30, 30, 10],
90+
'dur': [50, 300, 25, 100, 125, 50, 50, 100, 100]
9191
})
9292
note_df_complex_overlap_fixed = pd.DataFrame({
93-
'onset': [50, 75, 150, 200, 200, 300, 300],
94-
'track': [0, 0, 0, 0, 0, 0, 1],
95-
'pitch': [10, 10, 20, 10, 20, 30, 10],
96-
'dur': [25, 125, 50, 150, 50, 100, 100]
93+
'onset': [0, 50, 75, 150, 200, 200, 300, 300],
94+
'track': [0, 0, 0, 0, 0, 0, 0, 1],
95+
'pitch': [10, 10, 10, 20, 10, 20, 30, 10],
96+
'dur': [50, 25, 125, 50, 150, 50, 100, 100]
9797
})
9898
# midinote keyboard range from 0 to 127 inclusive
9999
all_midinotes = list(range(0, 128))

0 commit comments

Comments
 (0)