Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 4293a49

Browse files
note offs in data prep
1 parent a9b91f1 commit 4293a49

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

notepredictor/notepredictor/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,12 @@ def __getitem__(self, idx):
4141
# delta t of first note?
4242
time = time * (1 + random.random()*self.speed*2 - self.speed)
4343
# dequantize
44-
# TODO: use actual tactus from MIDI file
44+
# TODO: use actual tactus from MIDI file?
4545
time = (
4646
time + (torch.rand_like(time)-0.5)*2e-3
4747
).clamp(0., float('inf'))
4848

49+
# TODO: random velocity curve?
4950
velocity = (
5051
velocity +
5152
(torch.rand_like(time)-0.5) * ((velocity>0) & (velocity<127)).float()

notepredictor/scripts/lakh_prep.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,16 @@ def process(fnames, min_len=64):
2323
s_per_tick = micros_per_beat / mid.ticks_per_beat / 1e6
2424

2525
for i,tr in enumerate(mid.tracks):
26-
seq = [m for m in tr if m.type=='note_on' and m.velocity]
26+
seq = [m for m in tr if m.type=='note_on' or m.type=='note_off']
2727
if len(seq) < min_len:
2828
continue
2929
torch.save(dict(
3030
pitch=torch.LongTensor([m.note for m in seq]),
31-
velocity=torch.LongTensor([m.velocity for m in seq]),
31+
velocity=torch.LongTensor([
32+
m.velocity if m.type=='note_on' else 0 for m in seq]),
3233
time=torch.Tensor([m.time for m in seq])*s_per_tick,
33-
# src_track=i,
34-
# tempo=micros_per_beat,
34+
tempo=micros_per_beat,
35+
ticks=mid.ticks_per_beat
3536
), g.with_suffix(f'.{i}.pkl') )
3637

3738
def main(data_path, dest_path, n_jobs=4):

0 commit comments

Comments
 (0)