Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 68b28a7

Browse files
multitrack model
1 parent f4394f8 commit 68b28a7

File tree

6 files changed

+247
-136
lines changed

6 files changed

+247
-136
lines changed

examples/notepredictor/generate.scd

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ OSCdef(\return, {
5151
}, '/prediction', nil);
5252
t = Process.elapsedTime;
5353
b.sendMsg("/predictor/predict",
54-
\pitch, 60+12.rand, \time, 0, \vel, 0,
54+
\inst, 0, \pitch, 60+12.rand, \time, 0, \vel, 0,
5555
\pitch_temp, 0.5, \rhythm_temp, 0.5, \timing_temp, 0.1
5656
);
5757
)
5858

5959
// set the delay for more precise timing
60-
~delay = 0.016;
60+
~delay = 0.02;
6161

6262
// duet with the model
6363
// feeds the model's predictions back to it as well as player input
@@ -94,13 +94,14 @@ MIDIdef.noteOff(\input_off, {
9494
arg val, num, chan, src;
9595
var t2 = Process.elapsedTime;
9696
var dt = t2-(t?(t2-~delay)); //time since last note
97+
var inst = ~player_inst;
9798

9899
// cancel any pending predictions
99100
SystemClock.clear;
100101
~pending_predictions.postln;
101102
//get a new prediction in light of current note
102103
b.sendMsg("/predictor/predict",
103-
\pitch, num, \time, dt, \vel, 0,
104+
\inst, inst, \pitch, num, \time, dt, \vel, 0,
104105
\allow_start, false, \allow_end, false,
105106
\pitch_temp, 0.5, \rhythm_temp, 0.5, \timing_temp, 0.1,
106107
\min_time, ~delay, \max_time, 5
@@ -113,7 +114,7 @@ MIDIdef.noteOff(\input_off, {
113114
~synths[num] = nil;
114115

115116
// post the current note
116-
[\player, dt, num, 0].postln;
117+
[\player, dt, inst, num, 0].postln;
117118

118119
// mark time of current note
119120
t = t2;
@@ -128,13 +129,14 @@ MIDIdef.noteOn(\input_on, {
128129
arg val, num, chan, src;
129130
var t2 = Process.elapsedTime;
130131
var dt = t2-(t?(t2-~delay)); //time since last note
132+
var inst = ~player_inst;
131133

132134
// cancel any pending predictions
133135
SystemClock.clear;
134136
~pending_predictions.postln;
135137
//get a new prediction in light of current note
136138
b.sendMsg("/predictor/predict",
137-
\pitch, num, \time, dt, \vel, val,
139+
\inst, inst, \pitch, num, \time, dt, \vel, val,
138140
\allow_start, false, \allow_end, false,
139141
\pitch_temp, 0.5, \rhythm_temp, 0.5, \timing_temp, 0.1,
140142
\min_time, ~delay, \max_time, 5
@@ -150,7 +152,7 @@ MIDIdef.noteOn(\input_on, {
150152
~synths[num] = Synth(\pluck, [\freq, num.midicps, \vel, val/127]);//.release(1);
151153

152154
// post the current note
153-
[\player, dt, num, val].postln;
155+
[\player, dt, inst, num, val].postln;
154156

155157
// mark time of current note
156158
t = t2;
@@ -165,31 +167,33 @@ MIDIdef.noteOn(\input_on, {
165167
// OSC return from python
166168
OSCdef(\return, {
167169
arg msg, time, addr, recvPort;
168-
var num = msg[1]; // MIDI number of predicted note
169-
var dt = msg[2]; // time to predicted note
170-
var val = msg[3]; // velocity 0-127
171-
var step = msg[4];
170+
var inst = msg[1]; // instrument of predicted note
171+
var pitch = msg[2]; // MIDI number of predicted note
172+
var dt = msg[3]; // time to predicted note
173+
var val = msg[4]; // velocity 0-127
174+
var end = msg[5];
175+
var step = msg[6];
172176

173177
// time-to-next note gets 'censored' by the model
174178
// when over a threshold, in this case 10 seconds,
175179
// meaning it just predicts 10s rather than any longer time
176180
var censor = dt>=10.0;
177181

178182
~pending_predictions = ~pending_predictions-1;
179-
[\step, step].postln;
183+
// [\step, step].postln;
180184

181185

182186
censor.if{
183187
// if the predicted time is > 10 seconds, don't schedule it, just stop.
184188
\censor.postln;
185-
// ~synths[num]!?(_.release(3.0));
189+
// ~synths[pitch]!?(_.release(3.0));
186190
}{
187191
// schedule the predicted note
188192
SystemClock.sched(dt-~delay, {
189193
(~gate>0).if{
190194
var t2 = Process.elapsedTime;
191195
var dt_actual = t2 - t;
192-
(num==129).if{
196+
(end==1).if{
193197
// 129 is the 'stop token', meaning 'end-of-performance'
194198
// in this case don't schedule a note, and reset the model
195199
// b.sendMsg("/predictor/reset");
@@ -204,40 +208,36 @@ OSCdef(\return, {
204208
// (there shouldn't be any, but might
205209
// be if there was a lot of fast MIDI input)
206210
SystemClock.clear;
207-
~pending_predictions.postln;
211+
// ~pending_predictions.postln;
208212
// feed model its own prediction as input
209213
b.sendMsg("/predictor/predict",
210-
\pitch, num, \time, dt_actual, \vel, val,
214+
\inst, inst, \pitch, pitch, \time, dt_actual, \vel, val,
211215
\allow_start, false, \allow_end, true,
212216
\pitch_temp, 0.7, \rhythm_temp, 0.7, \timing_temp, 0.1,
213-
\min_time, ~delay*2, \max_time, 5,
214-
// \min_vel, 10
215-
// \fix_time, ((~step+1)%3==0).if{0.6}{0} // triads
216-
// \fix_time, (~step%8)*0.1 // specific rhythm
217-
217+
\min_time, ~delay*0, \max_time, 5,
218218
);
219219
~pending_predictions = ~pending_predictions+1;
220220

221221
// play the current note
222-
~synths[num]!?(_.release(0.05));
222+
~synths[pitch]!?(_.release(0.05));
223223
(val > 0).if{
224-
~synths[num] = Synth(\pluck, [\freq, num.midicps, \vel, val/127])
224+
~synths[pitch] = Synth(\pluck, [\freq, pitch.midicps, \vel, val/127])
225225
}{
226-
~synths[num] = nil
226+
~synths[pitch] = nil
227227
};
228228
// post the current note
229-
[\model, dt, num, val].postln;
229+
[\model, step, dt, inst, pitch, val, end].postln;
230230
// mark the actual time of current note
231231
t = t2;
232232
~machine_t = t;
233233
// crudely draw note on piano GUI
234234
~gui.if{
235-
AppClock.sched(0,{k.keyDown(num)});
236-
AppClock.sched(0.2,{k.keyUp(num)});
235+
AppClock.sched(0,{k.keyDown(pitch)});
236+
AppClock.sched(0.2,{k.keyUp(pitch)});
237237
}
238238
};
239239
~step = ~step+1;
240-
[\late, dt_actual-dt].postln;
240+
// [\late, dt_actual-dt].postln;
241241
}
242242
})};
243243

@@ -249,7 +249,8 @@ OSCdef(\return, {
249249
SystemClock.clear;
250250
~synths.do(_.release(1.0));
251251
b.sendMsg("/predictor/reset");
252-
{MIDIdef.all[\input].func.value(99, 60)}.defer(0.5);
252+
~player_inst = 83;
253+
{MIDIdef.all[\input_on].func.value(99, 60)}.defer(0.5);
253254
SystemClock.clear;
254255
)
255256
// b.sendMsg("/predictor/predict", \pitch, 70, \time, 0, \vel, 64);

examples/notepredictor/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _(address, **kw):
3939
print('no model loaded')
4040
else:
4141
r = predictor.predict(**kw)
42-
return '/prediction', r['pitch'], r['time'], r['velocity'], r['step']
42+
return '/prediction', r['instrument'], r['pitch'], r['time'], r['velocity'], r['end'], r['step']
4343

4444
elif cmd=="reset":
4545
if predictor is None:

notepredictor/notepredictor/data.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def __init__(self, data_dir, batch_len, transpose=2, speed=0.1, glob='**/*.pkl')
1717
self.transpose = transpose
1818
self.speed = speed
1919
self.start_token = 128
20-
self.end_token = 129
20+
self.n_anon = 8
21+
self.prog_start_token = 0
2122
# self.clamp_time = clamp_time
2223

2324
def __len__(self):
@@ -26,11 +27,29 @@ def __len__(self):
2627
def __getitem__(self, idx):
2728
f = self.files[idx]
2829
item = torch.load(f)
30+
program = item['program'] # 1-d LongTensor of MIDI programs 0-255
31+
# (128-255 are drums)
2932
pitch = item['pitch'] # 1-d LongTensor of MIDI pitches 0-127
3033
time = item['time']
3134
velocity = item['velocity']
35+
3236
assert len(pitch) == len(time)
3337

38+
unique_melodic = program.masked_select(program<128).unique()
39+
unique_drum = program.masked_select(program>=128).unique()
40+
41+
# # randomly map instruments to 'anonymous melodic' and 'anonymous drum'
42+
for pr in unique_melodic:
43+
if torch.rand((1,)) < 0.1:
44+
r = torch.randint(self.n_anon, size=(1,))
45+
program[program==pr] = r+256
46+
for pr in unique_drum:
47+
if torch.rand((1,)) < 0.1:
48+
r = torch.randint(self.n_anon, size=(1,))
49+
program[program==pr] = r+256+self.n_anon
50+
# shift from 0-index to general MIDI 1-index; use 0 for start token
51+
program += 1
52+
3453
# random transpose avoiding out of range notes
3554
transpose_down = min(self.transpose, pitch.min().item())
3655
transpose_up = min(self.transpose, 127-pitch.max())
@@ -39,25 +58,35 @@ def __getitem__(self, idx):
3958

4059
# random speed
4160
# delta t of first note?
61+
time = time.float()
4262
time = time * (1 + random.random()*self.speed*2 - self.speed)
4363
# dequantize
4464
# TODO: use actual tactus from MIDI file?
4565
time = (
4666
time + (torch.rand_like(time)-0.5)*2e-3
4767
).clamp(0., float('inf'))
4868

49-
# TODO: random velocity curve?
69+
# dequantize velocity
70+
velocity = velocity.float()
5071
velocity = (
5172
velocity +
5273
(torch.rand_like(time)-0.5) * ((velocity>0) & (velocity<127)).float()
5374
).clamp(0., 127.)
75+
# random velocity curve
76+
velocity /= 127
77+
velocity = velocity ** (2**(torch.randn((1,))/3))
78+
velocity *= 127
5479

55-
# pad with start, end tokens
56-
pad = max(1, self.batch_len-len(pitch))
80+
# pad with start tokens, zeros
81+
pad = max(0, self.batch_len-len(pitch))
82+
program = torch.cat((
83+
program.new_full((1,), self.prog_start_token),
84+
program,
85+
program.new_zeros((pad,))))
5786
pitch = torch.cat((
5887
pitch.new_full((1,), self.start_token),
5988
pitch,
60-
pitch.new_full((pad,), self.end_token)))
89+
pitch.new_zeros((pad,))))
6190
time = torch.cat((
6291
time.new_zeros((1,)),
6392
time,
@@ -66,16 +95,21 @@ def __getitem__(self, idx):
6695
velocity.new_zeros((1,)),
6796
velocity,
6897
velocity.new_zeros((pad,))))
98+
# end signal: nonzero for last event
99+
end = torch.zeros_like(program)
100+
end[-1] = 1
69101

70102
# random slice
71103
i = random.randint(0, len(pitch)-self.batch_len)
104+
program = program[i:i+self.batch_len]
72105
pitch = pitch[i:i+self.batch_len]
73106
time = time[i:i+self.batch_len]
74107
velocity = velocity[i:i+self.batch_len]
75-
76-
# time = time.clamp(*self.clamp_time)
108+
end = end[i:i+self.batch_len]
77109

78110
return {
111+
'end':end,
112+
'instrument':program,
79113
'pitch':pitch,
80114
'time':time,
81115
'velocity':velocity

0 commit comments

Comments
 (0)