This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit f7625c5

fix end prediction; schedule batch_len

1 parent: 78c2ab9

3 files changed: 32 additions, 15 deletions

examples/notepredictor/generate.scd (14 additions, 5 deletions)

@@ -20,7 +20,8 @@
 MIDIClient.init
 MIDIClient.destinations
 ~m1 = MIDIOut.newByName("IAC Driver", "Bus 1");
-~m2 = MIDIOut.newByName("IAC Driver", "IAC Bus 2");
+~m2 = MIDIOut.newByName("IAC Driver", "Bus 2");
+~m3 = MIDIOut.newByName("IAC Driver", "Bus 3");
 /*
 ~m1.noteOn(0, 60, 127)
 ~m2.noteOn(0, 60, 127)
@@ -35,7 +36,11 @@ MIDIClient.destinations
 var group = (prog-1 /8).asInteger;
 var idx = (prog-1 %8).asInteger;
 var port = switch(group)
-{ 0}{(idx<6).if{~m1}{~m2}} //piano
+{ 0}{case
+    {idx<4}{~m1} //acoustic
+    {idx<6}{~m2} //electric
+    {true}{~m3} //harpsichord
+} //piano
 { 1}{((idx<3)||(idx==5)).if{~m1}{~m2}} //chromatic perc
 { 2}{(idx<4).if{~m1}{~m2}} //organ
 { 3}{(idx<5).if{~m1}{~m2}} //guitar
@@ -69,7 +74,7 @@ MIDIClient.destinations
 }
 };
 ~release_all = {arg vel=0;
-[~m1, ~m2].do{arg port; 128.do{arg note; 16.do{arg chan; port.noteOff(chan, note, vel)}}}
+[~m1, ~m2, ~m3].do{arg port; 128.do{arg note; 16.do{arg chan; port.noteOff(chan, note, vel)}}}
 };
 )
@@ -202,6 +207,7 @@ MIDIdef.noteOn(\input_on, {
 //get a new prediction in light of current note
 b.sendMsg("/predictor/predict",
 \inst, inst, \pitch, num, \time, dt, \vel, val,
+// \fix_instrument, ~player_inst,
 \allow_start, false, \allow_end, false,
 \pitch_temp, 0.5, \rhythm_temp, 0.5, \timing_temp, 0.1,
 \min_time, ~delay, \max_time, 5
@@ -283,9 +289,12 @@ OSCdef(\return, {
 b.sendMsg("/predictor/predict",
 \inst, inst, \pitch, pitch, \time, dt_actual, \vel, vel,
 \allow_start, false, \allow_end, true,
-\instrument_temp, 1, \pitch_temp, 0.9, \rhythm_temp, 0.7, \timing_temp, 0.05,
+// \fix_instrument, ~player_inst,
+// \fix_time, 2.rand*0.1+~delay,
+\instrument_temp, 1, \pitch_temp, 0.9,
+\rhythm_temp, 1, \timing_temp, 0.05,
 // \instrument_temp, 1, \pitch_temp, 1, \rhythm_temp, 1, \timing_temp, 1,
-\min_time, ~delay,
+// \min_time, ~delay,
 \max_time, 5,

 );
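The new piano branch above splits the first General MIDI group across three IAC buses instead of two. For readers less used to SuperCollider's case syntax, here is a rough, purely illustrative Python sketch of the same routing decision (the return values "m1"/"m2"/"m3" just stand in for the three MIDIOut objects):

def piano_port(prog):
    # prog is a 1-based GM program number; group 0 covers programs 1-8
    idx = (prog - 1) % 8
    if idx < 4:
        return "m1"   # acoustic -> Bus 1
    elif idx < 6:
        return "m2"   # electric -> Bus 2
    else:
        return "m3"   # harpsichord -> Bus 3

So programs 1-4 land on Bus 1, 5-6 on Bus 2, and 7-8 on Bus 3, and ~release_all now flushes all three ports.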

notepredictor/notepredictor/model.py (10 additions, 9 deletions)

@@ -204,11 +204,11 @@ def __init__(self,
 p.weight.mul_(1e-2)
 self.end_proj.weight.mul(1e-2)

-# IDEA: instead of this, combine current embeddings (independently) with h via MLPs
-# stacked along a new final dim
-# matmul by mask, which is easier (?) to vary per batch/time
+# IDEA: instead of this, combine current embeddings (independently) with h by
+# projecting to h size, stacking with h along a new final dim,
+# matmul by n+1 x n mask, which is easier (?) to vary per batch/time
 # (compared to permute-and-cumsum)
-# then tanh, unbind and more independent MLPs -> dist params
+# then tanh, unbind and independent MLPs -> dist params
 self.xformer = ModalityTransformer(emb_size, ar_hidden, ar_heads, ar_layers)

 # persistent RNN state for inference
@@ -260,14 +260,13 @@ def forward(self, instruments, pitches, times, velocities, ends, validation=False
 t.expand(self.rnn.num_layers, x.shape[0], -1).contiguous() # 1 x batch x hidden
 for t in self.initial_state)
 h, _ = self.rnn(x, initial_state) #batch, time, hidden_size
-h = h[:,:-1] # skip last time position

 # fit all note factorizations (e.g. pitch->time->vel vs vel->time->pitch)
 # TODO: perm each batch item independently?
 # get a random ordering for note modalities:
 perm = torch.randperm(self.note_dim)
 # chunk RNN state into Transformer inputs
-hs = list(self.h_proj(h).chunk(self.note_dim+1, -1))
+hs = list(self.h_proj(h[:,:-1]).chunk(self.note_dim+1, -1)) # skip last time position
 h_ctx = hs[0]
 h_tgt = [hs[i+1] for i in perm]
 # embed ground truth values for teacher-forcing
@@ -299,10 +298,11 @@ def forward(self, instruments, pitches, times, velocities, ends, validation=False
 vel_log_probs = vel_result.pop('log_prob')

 # end prediction
-# skip the last position for convenience (so masking is the same)
-end_params = self.end_proj(h)
+# skip the first position for convenience
+# (so masking is the same for end as for note parts)
+end_params = self.end_proj(h[:,1:])
 end_logits = F.log_softmax(end_params, -1)
-end_log_probs = end_logits.gather(-1, ends[:,:-1,None])[...,0]
+end_log_probs = end_logits.gather(-1, ends[:,1:,None])[...,0]

 r = {
 'end_log_probs': end_log_probs,
@@ -553,6 +553,7 @@ def predict(self,
 pred_vel = predicted[iperm[3]]

 end_params = self.end_proj(h)
+print(end_params)
 end = D.Categorical(logits=end_params).sample()
 if not allow_end:
 end[:] = 0
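In forward(), the note heads still consume h[:, :-1], while the end head now consumes h[:, 1:] and gathers ends[:, 1:], so both produce T-1 positions and, as the new comment notes, the same masking applies to the end prediction as to the note parts. A small stand-alone shape check of that alignment, with hypothetical sizes and a plain Linear standing in for end_proj:

import torch

B, T, H = 2, 6, 8                        # hypothetical batch, time, hidden sizes
h = torch.randn(B, T, H)                 # RNN output: batch, time, hidden
ends = torch.randint(0, 2, (B, T))       # 0/1 end flag per position
end_proj = torch.nn.Linear(H, 2)         # stand-in for self.end_proj

h_notes = h[:, :-1]                                               # (B, T-1, H) context for the note heads
end_logits = torch.log_softmax(end_proj(h[:, 1:]), -1)            # (B, T-1, 2)
end_log_probs = end_logits.gather(-1, ends[:, 1:, None])[..., 0]  # (B, T-1)

assert h_notes.shape[1] == end_log_probs.shape[1]  # both T-1, so one mask fits both

At inference time, predict() still applies end_proj to the full current state h.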

notepredictor/scripts/train_notes.py (8 additions, 1 deletion)

@@ -25,6 +25,8 @@ def __init__(self,
 model = None, # dict of model constructor overrides
 batch_size = 128,
 batch_len = 64,
+batch_len_schedule = None,
+batch_len_max = 512,
 lr = 3e-4,
 adam_betas = (0.9, 0.999),
 adam_eps = 1e-08,
@@ -209,7 +211,7 @@ def validate():
 vel = batch['velocity'].to(self.device, non_blocking=True)

 self.iteration += 1
-self.exposure += self.batch_size # * self.batch_len
+self.exposure += self.batch_size * self.batch_len
 logs = {}

 ### forward+backward+optimizer step ###
@@ -232,6 +234,11 @@ def validate():

 validate()

+if self.batch_len_schedule is not None:
+    self.batch_len = min(
+        self.batch_len_max, self.batch_len+self.batch_len_schedule)
+    self.dataset.batch_len = self.batch_len
+
 self.save(self.model_dir / f'{self.epoch:04d}.ckpt')

 def deep_update(a, b):
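With batch_len_schedule left at its default of None, training behaves as before; when it is set, batch_len grows by that amount after each epoch's validation pass, capped at batch_len_max, and the dataset is updated to serve the longer sequences. The exposure counter now also scales with batch_len, so it counts events seen rather than sequences seen. A minimal sketch of the schedule's progression, using hypothetical values:

batch_len, batch_len_schedule, batch_len_max = 64, 32, 512  # hypothetical settings

lengths = []
for epoch in range(16):
    lengths.append(batch_len)  # sequence length used for this epoch's batches
    if batch_len_schedule is not None:
        batch_len = min(batch_len_max, batch_len + batch_len_schedule)

print(lengths)  # 64, 96, 128, ... rising by 32 per epoch and capping at 512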
