Skip to content

Commit 1d9255c

Browse files
committed
Move TODOs to Issues
Issues now written to address later: #98, #99, #100, #101, #102, #103, #104, #105, #106, #107, and #108
1 parent 7038857 commit 1d9255c

File tree

8 files changed

+3
-73
lines changed

8 files changed

+3
-73
lines changed

make_dataset.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def parse_args(args_input=None):
9999
parser.add_argument('--local-midi-dirs', metavar='midi_dir', type=str,
100100
nargs='*', help='directories containing midi files to '
101101
'include in the dataset', default=[])
102-
# TODO: check this works!
103102
parser.add_argument('--local-csv-dirs', metavar='csv_dir', type=str,
104103
nargs='*', help='directories containing csv files to '
105104
'include in the dataset', default=[])
@@ -114,12 +113,10 @@ def parse_args(args_input=None):
114113
'--download-cache-dir and --clear-download-cache. To '
115114
'download no data, provide an input of "None"',
116115
)
117-
# TODO: check this works!
118116
parser.add_argument('--download-cache-dir', type=str,
119117
default=downloaders.DEFAULT_CACHE_PATH, help='The '
120118
'directory to use for storing intermediate downloaded '
121119
'data e.g. zip files, and prior to preprocessing.')
122-
# TODO: check this works!
123120
parser.add_argument('--clear-download-cache', action='store_true',
124121
help='clear downloaded data cache')
125122
parser.add_argument('--degradations', metavar='deg_name', nargs='*',
@@ -136,7 +133,6 @@ def parse_args(args_input=None):
136133
parser.add_argument('--min-notes', metavar='N', type=int, default=10,
137134
help='The minimum number of notes required for an '
138135
'excerpt to be valid.')
139-
# TODO: check this works!
140136
parser.add_argument('--degradation-kwargs', metavar='json_string',
141137
help='json with keyword arguments for the '
142138
'degradation functions. First provide the degradation '
@@ -145,7 +141,6 @@ def parse_args(args_input=None):
145141
'kwarg. e.g. {"pitch_shift__distribution": "poisson", '
146142
'"pitch_shift__min_pitch: 5"}',
147143
type=json.loads, default=None)
148-
# TODO: check this works!
149144
parser.add_argument('--degradation-kwarg-json', metavar='json_file',
150145
help='A file containing parameters as described in '
151146
'--degradation-kwargs. If this file is given, '
@@ -242,7 +237,6 @@ def parse_args(args_input=None):
242237

243238

244239
# Instantiate downloaders =================================================
245-
# TODO: make OVERWRITE this an arg for the script
246240
OVERWRITE = None
247241
ds_names = ARGS.datasets
248242
if len(ds_names) == 1 and ds_names[0].lower() == 'none':
@@ -568,5 +562,4 @@ def parse_args(args_input=None):
568562

569563
print('\nTo reproduce this dataset again, run the script with argument '
570564
f'--seed {seed}')
571-
#TODO: print('see the examples directory for baseline models using this data')
572565
print(LOGO)

mdtk/data_structures.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,6 @@ def fix_overlapping_notes(df):
230230
bad_note = np.append(bad_note, False) # last note always fine
231231
df.loc[bad_note, 'dur'] = (next_note_on[bad_note[:-1]]
232232
- df.loc[bad_note, 'onset'].values)
233-
# TODO: add an assertion to catch dur==0 and add a test
234233
return df
235234

236235

@@ -246,7 +245,6 @@ def fix_overlapping_pitches(df):
246245
bad_note = np.append(bad_note, False) # last note always fine
247246
df.loc[bad_note, 'dur'] = (next_note_on[bad_note[:-1]]
248247
- df.loc[bad_note, 'onset'].values)
249-
# TODO: add an assertion to catch dur==0 and add a test
250248
return df
251249

252250

@@ -930,10 +928,6 @@ def __init__(self, note_df=None, csv_path=None,
930928
self.note_df = note_df
931929
# We do not assume that the supplied note_df is correctly formed,
932930
# and simply bomb out if it is not
933-
# TODO: implement df methods to fix issues instead e.g. overlaps.
934-
# Copy code from read_note_csv. e.g.:
935-
# * reorder columns
936-
# * if all columns but track and no extra cols, assume 1 trk
937931
if self.monophonic_tracks is not None:
938932
make_monophonic(self.note_df, tracks=monophonic_tracks)
939933
if self.max_note_len is not None:

mdtk/downloaders.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@
1919

2020

2121
# Classes =====================================================================
22-
# TODO: make attributes useful to users standard e.g. beat-aligned=True/False
23-
# TODO: some things are likely to be important re preprocessing e.g. the unit
24-
# for the onset and duration of notes. Add these as attributes too.
2522
class DataDownloader:
2623
"""Base class for data downloaders"""
2724
def __init__(self, cache_path=DEFAULT_CACHE_PATH):
@@ -60,20 +57,14 @@ def download_csv(self, output_path, cache_path=None, overwrite=None,
6057
'implement the download_csv method.')
6158

6259

63-
# TODO: since these datasets already have CSV in the right format, we should
64-
# implement download_csv() methods to use in favour of the download_midi
65-
# this method would reformat csv to be correct (cols need renaming etc.)
66-
# TODO: handle conversion from quarters to ms - use tempo data, but use min/max
67-
# tempo values as some were a bit spurious
60+
6861
class PPDDSep2018Monophonic(DataDownloader):
6962
"""Patterns for Preditction Development Dataset. Monophonic data only.
7063
7164
References
7265
----------
7366
https://www.music-ir.org/mirex/wiki/2019:Patterns_for_Prediction
7467
"""
75-
# TODO: add 'sample_size', to allow only a small random sample of the
76-
# total midi files to be copied to the output
7768
def __init__(self, cache_path=DEFAULT_CACHE_PATH,
7869
sizes=['small', 'medium', 'large'], clean=False):
7970
super().__init__(cache_path = cache_path)
@@ -136,8 +127,6 @@ class PPDDSep2018Polyphonic(PPDDSep2018Monophonic):
136127
----------
137128
https://www.music-ir.org/mirex/wiki/2019:Patterns_for_Prediction
138129
"""
139-
# TODO: add 'sample_size', to allow only a small random sample of the
140-
# total midi files to be copied to the output
141130
def __init__(self, cache_path=DEFAULT_CACHE_PATH,
142131
sizes=['small', 'medium', 'large'], clean=False):
143132
super().__init__(cache_path=cache_path, sizes=sizes, clean=clean)

mdtk/filesystem_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ def make_directory(path, overwrite=None, verbose=False):
6767

6868

6969
def extract_zip(zip_path, out_path, overwrite=None, verbose=False):
70-
"""Convenience function to extract zip file to out_path.
71-
TODO: make work for all types of zip files."""
70+
"""Convenience function to extract zip file to out_path."""
7271
if verbose:
7372
print(f'Extracting {zip_path} to {out_path}')
7473
dirname = os.path.splitext(os.path.basename(zip_path))[0]

mdtk/formatters.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ def diff_pd(df1, df2):
3131
return pd.DataFrame({'from': changed_from, 'to': changed_to},
3232
index=changed.index)
3333

34-
# TODO: later can auto detect vocab from corpus if necessary
35-
# I'm doing things this way just for ability to change things
36-
# later with ease
34+
3735
class CommandVocab(object):
3836
def __init__(self, min_pitch=MIN_PITCH_DEFAULT,
3937
max_pitch=MAX_PITCH_DEFAULT,
@@ -351,7 +349,6 @@ def df_to_command_str(df, min_pitch=MIN_PITCH_DEFAULT, max_pitch=MAX_PITCH_DEFAU
351349
assert time_increment > 0, "time_increment must be positive."
352350
assert max_time_shift > 0, "max_time_shift must be positive."
353351

354-
# TODO: This rounding may result in notes of length 0.
355352
note_off = df.loc[:, ['onset', 'pitch']]
356353
note_off['onset'] = note_off['onset'] + df['dur']
357354
note_off['cmd'] = note_off['pitch'].apply(lambda x: f'f{x}')

mdtk/pytorch_models.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size=2,
2929
self.vocab_size = vocab_size
3030

3131
self.embedding = nn.Embedding(vocab_size, embedding_dim)
32-
# TODO: try getting batch_first to work on this model
3332
self.lstm = nn.LSTM(embedding_dim, hidden_dim,
3433
num_layers=num_lstm_layers)
3534

@@ -49,7 +48,6 @@ def forward(self, batch, input_lengths=None):
4948
device = batch.device
5049
self.hidden = self.init_hidden(batch_size, device=device)
5150
embeds = self.embedding(batch).permute(1, 0, 2)
52-
# TODO: try getting batch_first to work on this model
5351
# embeds = self.embedding(batch)
5452
outputs, (ht, ct) = self.lstm(embeds, self.hidden)
5553
# ht is the last hidden state of the sequences

mdtk/pytorch_trainers.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010

1111

1212

13-
# TODO: I don't like the fomatter being passed in here - would prefer these
14-
# Trainers to be more general except for the iteration method for which you
15-
# hardcode how to do the train/test iteration.
1613
class BaseTrainer:
1714
"""Provides methods to train pytorch models. Adapted from:
1815
https://github.com/codertimo/BERT-pytorch/blob/master/bert_pytorch/trainer/pretrain.py"""
@@ -137,7 +134,6 @@ def save(self, file_path=None, epoch=None):
137134
print(f"Model saved {output_path}")
138135
return output_path
139136

140-
# TODO: implement load method (for use with load from checkpoint)
141137

142138

143139
class ErrorDetectionTrainer(BaseTrainer):
@@ -706,13 +702,6 @@ def iteration(self, epoch, data_loader, train=True, evaluate=False):
706702
total_data_points += len(input_data)
707703
for in_data, out_data, clean_data in \
708704
zip(input_data, model_output, labels):
709-
# TODO: Only 1 of these calls is necessary. deg and clean
710-
# could conceivably be returned by the data loader.
711-
# N.B. Currently, the precise min and max pitch don't
712-
# matter here. The converter just treats them all the same,
713-
# corrects and warns if the range doesn't make sense.
714-
# However, if loading deg and clean from the original df,
715-
# using the correct min and max pitch will be important.
716705
with warnings.catch_warnings():
717706
warnings.simplefilter("ignore")
718707
deg_df = self.formatter['model_to_df'](

mdtk/tests/test_data_structures.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -445,18 +445,7 @@ def test_pianoroll_all_pitches():
445445
assert (pianoroll == np.ones((1, 2, 128, 1), dtype='uint8')).all()
446446

447447

448-
# TODO: test all note_on occur with sounding
449-
450-
# TODO: test all note_off occur with sounding
451-
452-
# TODO: test all sounding begin note_on and end_note_off
453-
454-
# TODO: test all methods in pianoroll and all attributes
455-
456-
457448
# Composition class tests =====================================================
458-
# TODO: write import from csv tests
459-
460449
def test_composition_df_assertions():
461450
"""Essentially the same tests as test_check_note_df"""
462451
assertion = False
@@ -504,16 +493,6 @@ def test_composition_all_pitches():
504493

505494

506495

507-
# TODO: reimplement this if and when we implement auto fix of note_df
508-
#def test_auto_sort_onset_and_pitch():
509-
# comp = Composition(note_df=note_df_2pitch_aligned, fix_note_df=True)
510-
# assert comp.note_df.equals(
511-
# note_df_2pitch_aligned
512-
# .sort_values(['onset', 'pitch'])
513-
# .reset_index(drop=True)
514-
# )
515-
516-
517496
def test_not_ending_in_silence():
518497
for df in ALL_VALID_DF.values():
519498
comp = Composition(note_df=df)
@@ -555,19 +534,11 @@ def test_composition_read_csv():
555534
comp.plot()
556535
comp.synthesize()
557536

558-
def test_csv_and_df_imports_same():
559-
# TODO: write test that imports from all csvs and checks same as
560-
# importing from df
561-
pass
562537

563538

564-
# TODO: Check if anything alters input data - loop over all functions and
565-
# methods
566539

567540

568541
# Cleanup =====================================================================
569-
# TODO: This isn't technichally a test...should probably be some other function
570-
# look up the proper way to do this.
571542
def test_remove_csvs():
572543
for csv in ALL_CSV:
573544
os.remove(csv)

0 commit comments

Comments
 (0)