Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 78c2ab9

Browse files
mask loss for padding; redo preprocessing and dataloading to correctly randomize order of concurrent events
1 parent 69863b5 commit 78c2ab9

File tree

5 files changed

+1060
-64
lines changed

5 files changed

+1060
-64
lines changed

notepredictor/notebook/midi.ipynb

Lines changed: 951 additions & 11 deletions
Large diffs are not rendered by default.

notepredictor/notepredictor/data.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch.utils.data import Dataset, DataLoader
88

99
class MIDIDataset(Dataset):
10-
def __init__(self, data_dir, batch_len, transpose=2, speed=0.1, glob='**/*.pkl'):
10+
def __init__(self, data_dir, batch_len, transpose=5, speed=0.1, glob='**/*.pkl'):
1111
#, clamp_time=(-,10)):
1212
"""
1313
"""
@@ -56,15 +56,13 @@ def __getitem__(self, idx):
5656
transpose = random.randint(-transpose_down, transpose_up)
5757
pitch = pitch + transpose
5858

59-
# random speed
60-
# delta t of first note?
61-
time = time.float()
59+
60+
time_margin = 1e-3 # hardcoded since it should match prep script
61+
62+
# dequantize: add noise up to +/- margin
63+
time = time + (torch.rand_like(time)*2-1)*time_margin
64+
# random augment tempo
6265
time = time * (1 + random.random()*self.speed*2 - self.speed)
63-
# dequantize
64-
# TODO: use actual tactus from MIDI file?
65-
time = (
66-
time + (torch.rand_like(time)-0.5)*2e-3
67-
).clamp(0., float('inf'))
6866

6967
# dequantize velocity
7068
velocity = velocity.float()
@@ -77,8 +75,18 @@ def __getitem__(self, idx):
7775
velocity = velocity ** (2**(torch.randn((1,))/3))
7876
velocity *= 127
7977

78+
# sort (using argsort on time and indexing the rest)
79+
# compute delta time
80+
time, idx = time.sort()
81+
time = torch.cat((time.new_zeros((1,)), time)).diff(1)
82+
program = program[idx]
83+
pitch = pitch[idx]
84+
velocity = velocity[idx]
85+
8086
# pad with start tokens, zeros
81-
pad = max(0, self.batch_len-len(pitch))
87+
# always pad with batch_len so that end tokens don't appear in a biased
88+
# location
89+
pad = self.batch_len-1#max(0, self.batch_len-len(pitch))
8290
program = torch.cat((
8391
program.new_full((1,), self.prog_start_token),
8492
program,
@@ -95,13 +103,13 @@ def __getitem__(self, idx):
95103
velocity.new_zeros((1,)),
96104
velocity,
97105
velocity.new_zeros((pad,))))
98-
# end signal: nonzero for last event + padding
106+
# end signal: nonzero for last event
99107
end = torch.zeros_like(program)
100108
end[-pad-1:] = 1
101-
102-
mask = torch.zeros_like(program)
109+
# compute binary mask for the loss
110+
mask = torch.ones_like(program, dtype=torch.bool)
103111
if pad > 0:
104-
mask[-pad:] = 1
112+
mask[-pad:] = False
105113

106114
# random slice
107115
i = random.randint(0, len(pitch)-self.batch_len)

notepredictor/notepredictor/model.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ def __init__(self,
204204
p.weight.mul_(1e-2)
205205
self.end_proj.weight.mul(1e-2)
206206

207+
# IDEA: instead of this, combine current embeddings (independently) with h via MLPs
208+
# stacked along a new final dim
209+
# matmul by mask, which is easier (?) to vary per batch/time
210+
# (compared to permute-and-cumsum)
211+
# then tanh, unbind and more independent MLPs -> dist params
207212
self.xformer = ModalityTransformer(emb_size, ar_hidden, ar_heads, ar_layers)
208213

209214
# persistent RNN state for inference
@@ -255,14 +260,14 @@ def forward(self, instruments, pitches, times, velocities, ends, validation=Fals
255260
t.expand(self.rnn.num_layers, x.shape[0], -1).contiguous() # 1 x batch x hidden
256261
for t in self.initial_state)
257262
h, _ = self.rnn(x, initial_state) #batch, time, hidden_size
263+
h = h[:,:-1] # skip last time position
258264

259265
# fit all note factorizations (e.g. pitch->time->vel vs vel->time->pitch)
260266
# TODO: perm each batch item independently?
261267
# get a random ordering for note modalities:
262268
perm = torch.randperm(self.note_dim)
263269
# chunk RNN state into Transformer inputs
264-
hs = self.h_proj(h[:,:-1]) # skip last time position
265-
hs = list(hs.chunk(self.note_dim+1, -1))
270+
hs = list(self.h_proj(h).chunk(self.note_dim+1, -1))
266271
h_ctx = hs[0]
267272
h_tgt = [hs[i+1] for i in perm]
268273
# embed ground truth values for teacher-forcing
@@ -294,9 +299,10 @@ def forward(self, instruments, pitches, times, velocities, ends, validation=Fals
294299
vel_log_probs = vel_result.pop('log_prob')
295300

296301
# end prediction
302+
# skip the last position for convenience (so masking is the same)
297303
end_params = self.end_proj(h)
298304
end_logits = F.log_softmax(end_params, -1)
299-
end_log_probs = end_logits.gather(-1, ends[:,:,None])[...,0]
305+
end_log_probs = end_logits.gather(-1, ends[:,:-1,None])[...,0]
300306

301307
r = {
302308
'end_log_probs': end_log_probs,

notepredictor/scripts/lakh_prep_multitrack.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22
from multiprocessing import Pool
33
import functools as ft
4+
from collections import defaultdict
45
import random
56

67
from tqdm import tqdm
@@ -16,33 +17,68 @@ def process(fnames):
1617
except Exception:
1718
return
1819

19-
inst_events = []
20+
# fix overlapping notes and add a margin for
21+
# dequantization at data loading time
22+
time_margin = 1e-3
23+
24+
events = []
25+
# for each instrument
2026
for inst in mid.instruments:
2127
inst.remove_invalid_notes()
2228
program = inst.program + 128*inst.is_drum
23-
# NOTE: this will sort concurrent events by pitch
24-
# which will introduce some bias when interacting with the model?
25-
# e.g. if user plays a note, it will never be harmonized below (or only)
26-
# with inexact timing, less frequently
27-
# similarly the pitch order would correlate with instrument, i.e. bass
28-
# would usually play first
29-
# if anything descending pitch might sound better
30-
# could randomize -- even better would be to randomize in dataloader
31-
# might be expensive though
32-
note_ons = [(n.start, n.pitch, n.velocity, program) for n in inst.notes]
33-
note_offs = [(n.end, n.pitch, 0, program) for n in inst.notes]
34-
inst_events.extend(note_ons+note_offs)
35-
if len(inst_events) < 64:
29+
30+
# break out by pitch
31+
nbp = defaultdict(list)
32+
for n in inst.notes:
33+
nbp[n.pitch].append(n)
34+
35+
# shorten all notes so they end 2*$margin before next (within pitch)
36+
for seq in nbp.values():
37+
for i,n in enumerate(seq[:-1]):
38+
max_end = seq[i+1].start-2*time_margin
39+
if n.end > max_end:
40+
n.end = max_end
41+
# and flatten again
42+
# converting note offs to 0 velocity
43+
# also prevent any note ons from having 0 velocity
44+
events.append((n.start, n.pitch, max(1, n.velocity), program))
45+
events.append((n.end, n.pitch, 0, program))
46+
47+
if len(events) < 64:
3648
return
37-
time, pitch, vel, prog = zip(*sorted(inst_events))
38-
delta = torch.FloatTensor([0, *time]).diff(1)
49+
50+
time, pitch, vel, prog = zip(*events)
3951
torch.save(dict(
40-
time=delta,
52+
time=torch.FloatTensor(time),
4153
pitch=torch.LongTensor(pitch),
4254
velocity=torch.LongTensor(vel),
4355
program=torch.LongTensor(prog)
4456
), g.with_suffix('.pkl'))
4557

58+
59+
# # NOTE: this will sort concurrent events by pitch
60+
# # which will introduce some bias when interacting with the model?
61+
# # e.g. if user plays a note, it will never be harmonized below (or only)
62+
# # with inexact timing, less frequently
63+
# # similarly the pitch order would correlate with instrument, i.e. bass
64+
# # would usually play first
65+
# # if anything descending pitch might sound better
66+
# # could randomize -- even better would be to randomize in dataloader
67+
# # might be expensive though
68+
# note_ons = [(n.start, n.pitch, n.velocity, program) for n in inst.notes]
69+
# note_offs = [(n.end, n.pitch, 0, program) for n in inst.notes]
70+
# inst_events.extend(note_ons+note_offs)
71+
# if len(inst_events) < 64:
72+
# return
73+
# time, pitch, vel, prog = zip(*sorted(inst_events))
74+
# delta = torch.FloatTensor([0, *time]).diff(1)
75+
# torch.save(dict(
76+
# time=delta,
77+
# pitch=torch.LongTensor(pitch),
78+
# velocity=torch.LongTensor(vel),
79+
# program=torch.LongTensor(prog)
80+
# ), g.with_suffix('.pkl'))
81+
4682
def main(data_path, dest_path, n_jobs=4):
4783
data_dir = Path(data_path)
4884
files = list(data_dir.glob('**/*.mid'))

notepredictor/scripts/train_notes.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def __init__(self,
2323
log_dir,
2424
data_dir,
2525
model = None, # dict of model constructor overrides
26-
# clamp_time = (0,10), # given to trainer because it needs to go to dataset+model
2726
batch_size = 128,
2827
batch_len = 64,
2928
lr = 3e-4,
@@ -73,7 +72,7 @@ def __init__(self,
7372

7473
# Trainer state
7574
self.iteration = 0
76-
self.exposure = 0
75+
self.exposure = 0 # TODO: measure in events, no batch items
7776
self.epoch = 0
7877

7978
# construct model from arguments
@@ -139,14 +138,15 @@ def process_grad(self):
139138
self.model.parameters(), self.grad_clip, error_if_nonfinite=True)
140139
return r
141140

142-
def get_loss_components(self, result):
143-
# TODO: masking
141+
def get_loss_components(self, result, mask):
142+
def reduce(k):
143+
return result[k].masked_select(mask).mean()
144144
return {
145-
'instrument_nll': -result['instrument_log_probs'].mean(),
146-
'pitch_nll': -result['pitch_log_probs'].mean(),
147-
'time_nll': -result['time_log_probs'].mean(),
148-
'velocity_nll': -result['velocity_log_probs'].mean(),
149-
'end_nll': -result['end_log_probs'].mean()
145+
'instrument_nll': -reduce('instrument_log_probs'),
146+
'pitch_nll': -reduce('pitch_log_probs'),
147+
'time_nll': -reduce('time_log_probs'),
148+
'velocity_nll': -reduce('velocity_log_probs'),
149+
'end_nll': -reduce('end_log_probs'),
150150
}
151151

152152
def train(self):
@@ -165,6 +165,7 @@ def validate():
165165
metrics = defaultdict(float)
166166
self.model.eval()
167167
for batch in tqdm(valid_loader, desc=f'validating epoch {self.epoch}'):
168+
mask = batch['mask'].to(self.device, non_blocking=True)[...,1:]
168169
end = batch['end'].to(self.device, non_blocking=True)
169170
inst = batch['instrument'].to(self.device, non_blocking=True)
170171
pitch = batch['pitch'].to(self.device, non_blocking=True)
@@ -173,18 +174,19 @@ def validate():
173174
with torch.no_grad():
174175
result = self.model(
175176
inst, pitch, time, vel, end, validation=True)
176-
losses = {k:v.item() for k,v in self.get_loss_components(result).items()}
177+
losses = {k:v.item() for k,v in self.get_loss_components(
178+
result, mask).items()}
177179
metrics['loss'] += sum(losses.values())
178180
for k,v in losses.items():
179181
metrics[k] += v
180182
metrics['instrument_acc'] += (result['instrument_log_probs']
181-
.exp().mean().item())
183+
.masked_select(mask).exp().mean().item())
182184
metrics['pitch_acc'] += (result['pitch_log_probs']
183-
.exp().mean().item())
185+
.masked_select(mask).exp().mean().item())
184186
metrics['time_acc_30ms'] += (result['time_acc_30ms']
185-
.mean().item())
187+
.masked_select(mask).mean().item())
186188
metrics['velocity_acc'] += (result['velocity_log_probs']
187-
.exp().mean().item())
189+
.masked_select(mask).exp().mean().item())
188190
self.log('valid', {k:v/len(valid_loader) for k,v in metrics.items()})
189191

190192
epoch_size = self.epoch_size or len(train_loader)
@@ -199,29 +201,33 @@ def validate():
199201
self.model.train()
200202
for batch in tqdm(it.islice(train_loader, epoch_size),
201203
desc=f'training epoch {self.epoch}', total=epoch_size):
202-
204+
mask = batch['mask'].to(self.device, non_blocking=True)
203205
end = batch['end'].to(self.device, non_blocking=True)
204206
inst = batch['instrument'].to(self.device, non_blocking=True)
205207
pitch = batch['pitch'].to(self.device, non_blocking=True)
206208
time = batch['time'].to(self.device, non_blocking=True)
207209
vel = batch['velocity'].to(self.device, non_blocking=True)
208210

209211
self.iteration += 1
210-
self.exposure += self.batch_size
211-
212+
self.exposure += self.batch_size # * self.batch_len
212213
logs = {}
213214

215+
### forward+backward+optimizer step ###
214216
self.opt.zero_grad()
215217
result = self.model(inst, pitch, time, vel, end)
216-
losses = self.get_loss_components(result)
218+
losses = self.get_loss_components(result, mask[...,1:])
217219
loss = sum(losses.values())
218220
loss.backward()
219221
logs |= self.process_grad()
220222
self.opt.step()
223+
########
221224

225+
# log loss components
222226
logs |= {k:v.item() for k,v in losses.items()}
223-
logs |= {k:v.item() for k,v in result.items() if v.numel()==1}
227+
# log total loss
224228
logs |= {'loss':loss.item()}
229+
# log any other returned scalars
230+
logs |= {k:v.item() for k,v in result.items() if v.numel()==1}
225231
self.log('train', logs)
226232

227233
validate()

0 commit comments

Comments
 (0)