Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit a9b91f1

Browse files
pitch top p; separate rhythm top p from timing temp; velocity min/max
1 parent d6cf92c commit a9b91f1

File tree

4 files changed

+119
-46
lines changed

4 files changed

+119
-46
lines changed

examples/notepredictor/generate.scd

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
// in this example the model's predictions are fed back to it so it plays itself.
22
// the player can add notes as well and start/stop/reset the model with a footswitch.
33

4+
// TODO: steerable generation. gui for ranges, temperatures;
5+
// MIDI controller for pitch set
6+
47
(
58
~gui = false;
69
MIDIIn.connectAll;
@@ -18,26 +21,30 @@ s.boot;
1821
(
1922
SynthDef(\pluck, {
2023
var vel = \vel.kr;
21-
var signal = Saw.ar(\freq.kr, 0.2) * EnvGate.new(1);
22-
var fr = 2.pow(Decay.ar(Impulse.ar(0), 3)*6*vel+8);
24+
var freq = \freq.kr;
25+
var fl = freq.log2 - 1;
26+
var signal = Saw.ar(freq, 0.2) * EnvGate.new(1);
27+
var fr = 2.pow(Decay.ar(Impulse.ar(0), 3)*(13-fl)*vel+fl);
2328
signal = BLowPass.ar(signal, fr)*vel;
2429
Out.ar([0,1], signal);
2530
}).add
2631
)
2732

28-
2933
// measure round-trip latency
3034
(
3135
OSCdef(\return, {
3236
arg msg, time, addr, recvPort;
3337
(Process.elapsedTime - t).postln;
3438
}, '/prediction', nil);
3539
t = Process.elapsedTime;
36-
b.sendMsg("/predictor/predict", \pitch, 60+12.rand, \time, 0, \vel, 0);
40+
b.sendMsg("/predictor/predict",
41+
\pitch, 60+12.rand, \time, 0, \vel, 0,
42+
\pitch_temp, 0.5, \rhythm_temp, 0.5, \timing_temp, 0.1
43+
);
3744
)
3845

3946
// set the delay for more precise timing
40-
~delay = 0.015;
47+
~delay = 0.016;
4148

4249
// duet with the model
4350
// feeds the model's predictions back to it as well as player input
@@ -71,7 +78,7 @@ MIDIdef.program(\switch, {
7178
MIDIdef.noteOn(\input, {
7279
arg val, num, chan, src;
7380
var t2 = Process.elapsedTime;
74-
var dt = t2-(t?t2); //time since last note
81+
var dt = t2-(t?(t2-~delay)); //time since last note
7582

7683
// cancel any pending predictions
7784
SystemClock.clear;
@@ -80,12 +87,13 @@ MIDIdef.noteOn(\input, {
8087
b.sendMsg("/predictor/predict",
8188
\pitch, num, \time, dt, \vel, val,
8289
\allow_start, false, \allow_end, false,
83-
\time_temp, 0, \min_time, 0.1, \max_time, 5
84-
// \fix_time, 9
90+
\pitch_temp, 0.5, \rhythm_temp, 0.5, \timing_temp, 0.1,
91+
\min_time, ~delay, \max_time, 5
92+
// \fix_time, ~delay
8593
);
8694

8795
// release the previous note
88-
y.release(0.1);
96+
y.release(0.05);
8997

9098
// play the current note
9199
y = Synth(\pluck, [\freq, num.midicps, \vel, val/127]);//.release(1);
@@ -98,9 +106,9 @@ MIDIdef.noteOn(\input, {
98106
~player_t = t;
99107

100108
~step = ~step + 1;
109+
// ~step = 0;
101110
});
102111

103-
104112
// OSC return from python
105113
OSCdef(\return, {
106114
arg msg, time, addr, recvPort;
@@ -111,7 +119,7 @@ OSCdef(\return, {
111119
// time-to-next note gets 'censored' by the model
112120
// when over a threshold, in this case 10 seconds,
113121
// meaning it just predicts 10s rather than any longer time
114-
var censor = dt>10.0;
122+
var censor = dt>=10.0;
115123

116124
censor.if{
117125
// if the predicted time is >= 10 seconds (i.e. censored), don't schedule it, just stop.
@@ -125,12 +133,13 @@ OSCdef(\return, {
125133
(num==129).if{
126134
// 129 is the 'stop token', meaning 'end-of-performance'
127135
// in this case don't schedule a note, and reset the model
128-
b.sendMsg("/predictor/reset");
136+
// b.sendMsg("/predictor/reset");
129137
//release the last note
130138
y.release(1.0);
131139
// unset time so next note will have dt=0
132-
t = nil;
133-
\reset.postln
140+
// t = nil;
141+
// \reset.postln
142+
\end.postln;
134143
}{
135144
// cancel any pending predictions
136145
// (there shouldn't be any, but might
@@ -139,9 +148,11 @@ OSCdef(\return, {
139148
// feed model its own prediction as input
140149
b.sendMsg("/predictor/predict",
141150
\pitch, num, \time, dt_actual, \vel, val,
142-
\allow_start, false, \allow_end, false,
143-
\time_temp, 0.1, \min_time, 0.1, \max_time, 5
144-
// \fix_time, (~step%4==0).if{0.6}{0} // tetrachords
151+
\allow_start, false, \allow_end, true,
152+
\pitch_temp, 0.7, \rhythm_temp, 0.7, \timing_temp, 0.1,
153+
\min_time, ~delay*2, \max_time, 5,
154+
\min_vel, 10
155+
// \fix_time, ((~step+1)%3==0).if{0.6}{0} // triads
145156
// \fix_time, (~step%8)*0.1 // specific rhythm
146157

147158
);
@@ -151,7 +162,7 @@ OSCdef(\return, {
151162
y.release(1.0)
152163
}{
153164
// otherwise release fast to play a melody
154-
y.release(0.1)
165+
y.release(0.05)
155166
};
156167
// play the current note
157168
y = Synth(\pluck, [
@@ -175,7 +186,15 @@ OSCdef(\return, {
175186
}, "/prediction", nil);
176187
)
177188

189+
190+
(
178191
// send a note manually if you don't have a MIDI controller:
192+
SystemClock.clear;
193+
y.release(0.2);
194+
b.sendMsg("/predictor/reset");
195+
{MIDIdef.all[\input].func.value(99, 60)}.defer(0.5);
196+
SystemClock.clear;
197+
)
179198
// b.sendMsg("/predictor/predict", \pitch, 70, \time, 0, \vel, 64);
180199

181200
// load another model

iipyper/iipyper/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@ async def _run_async():
2525
# start OSC server
2626
for osc in OSC.instances:
2727
await osc.create_server(asyncio.get_event_loop())
28-
# osc.create_client()
2928

3029
for midi in MIDI.instances:
3130
asyncio.create_task(midi_coroutine(midi))
32-
# asyncio.create_task(midi.get_coroutine())
3331

3432
# start loop tasks
3533
if len(_loop_fns):

notepredictor/notepredictor/distributions.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,21 @@
66
import torch.distributions as D
77
import torch.nn.functional as F
88

9+
def reweight_top_p(probs, top_p):
10+
"""given tensor of probabilities, apply top p / "nucleus" filtering"""
11+
# NOTE: this is fudged slightly, it doesn't 'interpolate' the cutoff bin
12+
desc_probs, idx = probs.sort(-1, descending=True)
13+
iidx = idx.argsort(-1)
14+
cumprob = desc_probs.cumsum(-1)
15+
# first index where cumprob >= top_p is the last index we don't zero
16+
to_zero = (cumprob >= top_p).roll(1, -1)
17+
to_zero[...,0] = False
18+
# unsort
19+
to_zero = to_zero.gather(-1, iidx)
20+
weighted_probs = torch.zeros_like(probs).where(to_zero, probs)
21+
return weighted_probs / weighted_probs.sum(-1, keepdim=True)
22+
23+
924
class CensoredMixtureLogistic(nn.Module):
1025
def __init__(self, n, res=1e-2, lo='-inf', hi='inf',
1126
sharp_bounds=(1e-4,2e3), init=None):
@@ -113,16 +128,22 @@ def cdf(self, h, x):
113128

114129
def cdf_components(self, loc, s, x):
115130
x_ = (x[...,None] - loc) * s
116-
return x_.sigmoid()
131+
return x_.sigmoid()
117132

118-
def sample(self, h, truncate=None, shape=None, temp=None, bias=None):
133+
# TODO: 'discrete_sample' method which would re-quantize and then allow
134+
# e.g. nucleus sampling on the categorical distribution?
135+
def sample(self, h, truncate=None, shape=None,
136+
weight_top_p=None, component_temp=None, bias=None):
119137
"""
120138
Args:
121139
h: Tensor[...,n_params]
122140
truncate: Optional[Tuple[2]]. lower and upper bound for truncation.
123141
shape: Optional[int]. additional sample shape to be prepended to dims.
124-
temp: Optional[float]. pseudo-temperature (temperature of each mixture
125-
component). default is 1. 0 would sample component location only,
142+
weight_top_p: top_p ("nucleus") filtering for mixture weights.
143+
default is 1 (no change to distribution). 0 would sample top
144+
component (after truncation) only.
145+
component_temp: Optional[float]. sampling temperature of each mixture
146+
component. default is 1. 0 would sample component location only,
126147
ignoring sharpness.
127148
bias: applied outside of truncation but inside of clamping,
128149
useful e.g. for latency correction when sampling delta-time
@@ -139,20 +160,31 @@ def sample(self, h, truncate=None, shape=None, temp=None, bias=None):
139160
truncate = (-np.inf, np.inf)
140161
truncate = torch.tensor(truncate)
141162

142-
if temp is None:
143-
temp = 1
163+
if component_temp is None:
164+
component_temp = 1
144165

145166
if bias is None:
146167
bias = 0
147168

148169
log_pi, loc, s = self.get_params(h)
170+
s = s/component_temp
149171
scale = 1/s
150172

151173
# cdfs: [...,bound,component]
152174
cdfs = self.cdf_components(loc[...,None,:], s[...,None,:], truncate)
153175
# prob. mass of each component within bounds
154176
trunc_probs = cdfs[...,1,:] - cdfs[...,0,:] # [...,component]
155177
probs = log_pi.exp() * trunc_probs # reweighted mixture component probs
178+
if weight_top_p is not None:
179+
# reweight with top_p
180+
probs = reweight_top_p(probs, weight_top_p)
181+
182+
## DEBUG
183+
# print(loc)
184+
# print(s)
185+
# print(trunc_probs)
186+
# print(probs)
187+
#, log_pi.exp(), trunc_probs)
156188

157189
c = D.Categorical(probs).sample((shape,))
158190
# move sample dimension first
@@ -166,7 +198,7 @@ def sample(self, h, truncate=None, shape=None, temp=None, bias=None):
166198
u = u * (upper-lower) + lower
167199

168200
# x = loc + scale * (u.log() - (1 - u).log())
169-
x = loc + bias - scale * temp * (1/u - 1).log()
201+
x = loc + bias - scale * (1/u - 1).log()
170202
x = x.clamp(self.lo, self.hi)
171203
return x[0] if unwrap else x
172204

notepredictor/notepredictor/model.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch.distributions as D
99

1010
from .rnn import GenericRNN
11-
from .distributions import CensoredMixtureLogistic
11+
from .distributions import CensoredMixtureLogistic, reweight_top_p
1212

1313
class SineEmbedding(nn.Module):
1414
def __init__(self, n, w0=1e-3, interval=1.08):
@@ -211,7 +211,9 @@ def cell_state(self):
211211

212212
def get_samplers(self,
213213
pitch_topk=None, index_pitch=None, allow_start=False, allow_end=False,
214-
sweep_time=False, min_time=None, max_time=None, bias_time=None, time_temp=None):
214+
pitch_top_p=None,
215+
sweep_time=False, min_time=None, max_time=None, bias_time=None, time_weight_top_p=None, time_component_temp=None,
216+
min_vel=None, max_vel=None):
215217
"""
216218
this method converts the many arguments to `predict` into functions for
217219
sampling each note modality (e.g. pitch, time, velocity)
@@ -227,7 +229,10 @@ def sample_pitch(x):
227229
elif pitch_topk is not None:
228230
return x.argsort(-1, True)[...,:pitch_topk].transpose(0,-1)
229231
else:
230-
return D.Categorical(logits=x).sample()
232+
probs = x.softmax(-1)
233+
if pitch_top_p is not None:
234+
probs = reweight_top_p(probs, pitch_top_p)
235+
return D.Categorical(probs).sample()
231236

232237
def sample_time(x):
233238
# TODO: respect trunc_time when sweep_time is True
@@ -247,12 +252,19 @@ def sample_time(x):
247252
-np.inf if min_time is None else min_time,
248253
np.inf if max_time is None else max_time)
249254
return self.time_dist.sample(x,
250-
truncate=trunc, temp=time_temp, bias=bias_time)
255+
truncate=trunc, bias=bias_time,
256+
component_temp=time_component_temp, weight_top_p=time_weight_top_p)
257+
258+
def sample_velocity(x):
259+
trunc = (
260+
-np.inf if min_vel is None else min_vel,
261+
np.inf if max_vel is None else max_vel)
262+
return self.vel_dist.sample(x, truncate=trunc)
251263

252264
return (
253265
sample_pitch,
254266
sample_time,
255-
lambda x: self.vel_dist.sample(x),
267+
sample_velocity,
256268
)
257269

258270
@property
@@ -340,12 +352,14 @@ def forward(self, pitches, times, velocities, validation=False):
340352
)
341353
return r
342354

343-
# TODO: force
355+
# TODO: remove pitch_topk and sweep_time?
344356
def predict(self,
345357
pitch, time, vel,
346358
fix_pitch=None, fix_time=None, fix_vel=None,
347359
pitch_topk=None, index_pitch=None, allow_start=False, allow_end=False,
348-
sweep_time=False, min_time=None, max_time=None, bias_time=None, time_temp=None):
360+
sweep_time=False, min_time=None, max_time=None, bias_time=None,
361+
pitch_temp=None, rhythm_temp=None, timing_temp=None,
362+
min_vel=None, max_vel=None):
349363
"""
350364
consume the most recent note and return a prediction for the next note.
351365
@@ -372,19 +386,25 @@ def predict(self,
372386
bias_time: add this delay to the time
373387
(after applying min/max but before clamping to 0).
374388
may be useful for latency correction.
375-
time_temp: if not None, apply pseudo-temperature to the time distribution.
376-
i.e., scale the temperature of each mixture component.
377-
this is not technically the same as changing the temperature of the whole
378-
time distribution, but it can be useful if we assume each component
379-
corresponds to a different rhythmic interval. then passing `time_temp=0`
380-
would lead to more rhythmically steady, less random playing.
389+
pitch_temp: if not None, apply top_p sampling to pitch. 0 is
390+
deterministic, 1 is 'natural' according to the model
391+
rhythm_temp: if not None, apply top_p sampling to the weighting
392+
of mixture components. this affects coarse rhythmic patterns; 0 is
393+
deterministic, 1 is 'natural' according to the model
394+
timing_temp: if not None, apply temperature sampling to the time
395+
component. this affects fine timing; 0 is deterministic and precise,
396+
1 is 'natural' according to the model.
397+
min_vel, max_vel: if not None, truncate the velocity distribution
381398
382399
Returns: dict of
383400
'pitch': int. predicted MIDI number of next note.
384401
'time': float. predicted time to next note.
385402
'velocity': float. unquantized predicted velocity of next note.
386403
'*_params': tensor. distribution parameters for visualization purposes.
387404
"""
405+
if (index_pitch is not None) and (pitch_temp is not None):
406+
print("warning: `index pitch` overrides `pitch_temp`")
407+
388408
with torch.no_grad():
389409
pitch = torch.LongTensor([[pitch]]) # 1x1 (batch, time)
390410
time = torch.FloatTensor([[time]]) # 1x1 (batch, time)
@@ -409,7 +429,10 @@ def predict(self,
409429
self.projections,
410430
self.get_samplers(
411431
pitch_topk, index_pitch, allow_start, allow_end,
412-
sweep_time, min_time, max_time, bias_time, time_temp),
432+
pitch_temp,
433+
sweep_time, min_time, max_time, bias_time,
434+
rhythm_temp, timing_temp,
435+
min_vel, max_vel),
413436
self.embeddings,
414437
))
415438

@@ -431,10 +454,11 @@ def predict(self,
431454
for i,(item, embed) in enumerate(zip(fix, self.embeddings)):
432455
if item is None:
433456
if (
434-
i==1 and (sweep_time
435-
or (min_time is not None) or (max_time is not None)
436-
or (time_temp is not None)) or
437-
i==0 and pitch_topk
457+
i==0 and (pitch_topk or pitch_temp is not None) or
458+
i==1 and any(p is not None for p in (
459+
min_time, max_time, rhythm_temp, timing_temp)) or
460+
i==2 and any(p is not None for p in (
461+
min_vel, max_vel))
438462
):
439463
cons_idx.append(i)
440464
else:
@@ -449,7 +473,7 @@ def predict(self,
449473
iperm = np.argsort(perm) # inverse permutation back to canonical order
450474

451475
md = ['pitch', 'time', 'vel']
452-
print([md[i] for i in perm])
476+
print('sampling order:', [md[i] for i in perm])
453477

454478
# for each undetermined modality,
455479
# sample a new value conditioned on already determined ones

0 commit comments

Comments
 (0)