Skip to content

Commit 5d33d63

Browse files
update to Lux 1.0
1 parent 4789f2d commit 5d33d63

File tree

7 files changed

+35
-41
lines changed

7 files changed

+35
-41
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PianoHands"
22
uuid = "74435128-bd9d-4d82-978b-bd768beb391e"
33
authors = ["NeroBlackstone <[email protected]>"]
4-
version = "0.1.0"
4+
version = "0.2.0"
55

66
[deps]
77
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
@@ -26,7 +26,7 @@ test = ["Test"]
2626
CodecZlib = "0.7"
2727
IterTools = "1.10"
2828
JLD2 = "0.4"
29-
Lux = "0.5"
29+
Lux = "1.0"
3030
LuxCUDA = "0.3"
3131
MIDI = "2.7"
3232
MLUtils = "0.4"

README.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
11
# PianoHands.jl
22
Predicting hand assignments in piano MIDI using neural networks
33

4-
# Use Pre-trained weight
4+
# Use pre-trained model
55

66
``` julia
77
using PianoHands
88
generate_midi("./your_midi.mid";)
99
```
1010

11-
You will get a midi file `out.mid`, track 1 is left hand notes, track 2 is right hand notes.
11+
You will get a MIDI file `your_midi_out.mid`: track 1 contains the left-hand notes and track 2 contains the right-hand notes.
1212

13-
# Train Your own weight.
13+
# Train your own model
1414

1515
## Dataset preparation
1616

1717
Download the PIG v1.2 dataset to `PianoFingeringDataset` and remove duplicate fingering files; approximately 150 fingering files are required.
1818

1919
``` julia
20-
train_piano(DATASET_PATH,
20+
function train_piano(DATASET_PATH,
2121
TESTSET_PATH;
22-
BATCH_SIZE = 10,
23-
SEQ_LENGTH = 70,
22+
BATCH_SIZE = 12,
23+
SEQ_LENGTH = 75,
2424
HIDDEN_SIZE = 14,
25-
LEARNING_RATE = 0.0002f0,
25+
LEARNING_RATE = 0.0005f0,
2626
MAX_EPOCH = 200,
27-
EVALUATE_PER_N_TRAIN = 100)
27+
EVALUATE_PER_N_TRAIN = 50)
2828
```
2929

3030
The network structure is a bi-directional GRU + Dense, and the hidden layer size can be adjusted via the `HIDDEN_SIZE` parameter. There is no stopping condition for training; you need to stop it manually.
@@ -33,8 +33,8 @@ Use trained weight:
3333

3434
```julia
3535
generate_midi(input_file::String;
36-
output_file="./out.mid",
37-
weight_file=pkgdir(PianoHands,"weight","weight-0.92757.jld2"),
36+
output_file="",
37+
weight_file=pkgdir(PianoHands,"model","model-0.91502.jld2"),
3838
HIDDEN_SIZE=14)
3939
```
4040

model/model-0.91502.jld2

21.8 KB
Binary file not shown.

src/data_processing.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ function get_train_dataloaders(dataset_path::String; batch_size=10, seq_length=2
6969
push!(feature_result,stack(stack.(partition(features,seq_length,1))))
7070
push!(label_result,stack(stack.(partition(labels,seq_length,1))))
7171
end
72-
return DataLoader((cat(feature_result...;dims=3),cat(label_result...;dims=2)); batchsize = batch_size, shuffle=true, parallel = true)
72+
return DataLoader((cat(feature_result...;dims=3),cat(label_result...;dims=2));
73+
batchsize = batch_size, shuffle=true, parallel = true)
7374
end
7475

7576
"""
@@ -96,7 +97,9 @@ Predict left hand or right hand by output.
9697
"""
9798
predict_y(y) = y > 0.5f0 ? 1 : 0
9899

99-
function generate_midi(input_file::String;output_file="./out.mid",weight_file=pkgdir(PianoHands,"weight","weight-0.92757.jld2"),HIDDEN_SIZE=14)
100+
function generate_midi(input_file::String; output_file::String="",
101+
weight_file=pkgdir(PianoHands,"model","model-0.91502.jld2"),HIDDEN_SIZE=14)
102+
100103
midi_file = load(input_file)
101104
hand_classify = inferance_midi(midi_file,weight_file,HIDDEN_SIZE)
102105

@@ -115,5 +118,5 @@ function generate_midi(input_file::String;output_file="./out.mid",weight_file=pk
115118
addnotes!(track_rh, notes_rh)
116119
addtrackname!(track_rh, "piano right")
117120
push!(new_midi_file.tracks, track_lh, track_rh)
118-
save(output_file, new_midi_file)
121+
save(isempty(output_file) ? first(splitext(input_file))*"_out.mid" : output_file, new_midi_file)
119122
end

src/training.jl

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,26 @@ end
77

88
function train_piano(DATASET_PATH,
99
TESTSET_PATH;
10-
BATCH_SIZE = 10,
11-
SEQ_LENGTH = 70,
10+
BATCH_SIZE = 12,
11+
SEQ_LENGTH = 75,
1212
HIDDEN_SIZE = 14,
13-
LEARNING_RATE = 0.0002f0,
13+
LEARNING_RATE = 0.0005f0,
1414
MAX_EPOCH = 200,
15-
EVALUATE_PER_N_TRAIN = 100)
15+
EVALUATE_PER_N_TRAIN = 50)
16+
17+
dev = gpu_device()
1618

1719
# Get the dataloaders
18-
train_loader = get_train_dataloaders(DATASET_PATH;batch_size=BATCH_SIZE, seq_length=SEQ_LENGTH)
20+
train_loader = get_train_dataloaders(DATASET_PATH;batch_size=BATCH_SIZE, seq_length=SEQ_LENGTH) |> dev
1921
val_x, val_y = get_val_datas(TESTSET_PATH)
2022

2123
# Create the model
2224
model = build_model(GRUCell,HIDDEN_SIZE)
2325
display(model)
2426
rng = Xoshiro(0)
25-
dev = gpu_device()
26-
train_state = Lux.Experimental.TrainState(rng, model, Adam(LEARNING_RATE); transform_variables=dev)
27+
28+
ps, st = Lux.setup(rng, model) |> dev
29+
train_state = Training.TrainState(model, ps, st,Adam(LEARNING_RATE))
2730

2831
logitbce = BinaryCrossEntropyLoss();
2932
loss_fn(ŷ,y) = sum(logitbce.(vec.(ŷ),eachslice(y;dims=1)))
@@ -40,28 +43,21 @@ function train_piano(DATASET_PATH,
4043
loss_sum = 0
4144
# Train the model
4245
for (x,y) in train_loader
43-
x = x |> dev
44-
y = y |> dev
45-
46-
(_, loss, _, train_state) = Lux.Experimental.single_train_step!(
46+
(_, loss, _, train_state) = Training.single_train_step!(
4747
AutoZygote(), compute_loss, (x, y), train_state)
48-
4948
i+=1
5049
loss_sum += loss
5150
if i % EVALUATE_PER_N_TRAIN == 0
5251
@printf "Epoch [%3d]: Loss %4.5f\n" epoch loss_sum/i
53-
5452
# Validate the model
5553
st_ = Lux.testmode(train_state.states)
5654
matchs = 0
5755
note_count = mapreduce(length,+,val_y)
5856
loss_sum_in = 0
5957
for (x, y) in zip(val_x,val_y)
60-
x = reshape(x, Val(3)) |> dev
58+
x = x |> dev
6159
y = y |> dev
62-
6360
ŷ, st_ = model(x, train_state.parameters, st_)
64-
6561
loss_sum_in += loss_fn(ŷ, y)
6662
matchs += matches_num(vcat(ŷ...),y)
6763
end
@@ -77,25 +73,20 @@ function train_piano(DATASET_PATH,
7773
heightest_acc = acc
7874
end
7975
end
80-
81-
8276
end
8377
i = 1
8478
end
8579
end
8680

8781
function inferance_midi(midi_file::MIDIFile,weight_file::String,HIDDEN_SIZE::Int)::Vector{Int}
88-
f = midi_to_features(midi_file)
89-
x = reshape(stack(f), Val(3))
90-
9182
model = build_model(GRUCell,HIDDEN_SIZE)
9283
display(model)
9384
dev = gpu_device()
9485
@load weight_file ps_trained st_trained
9586
ps_trained,st_trained |> dev
9687

9788
st_ = Lux.testmode(st_trained)
98-
y, st_ = model(x, ps_trained, st_)
89+
y, st_ = model(stack(midi_to_features(midi_file)), ps_trained, st_)
9990
y |> cpu_device()
10091
return (predict_y ∘ first).(y)
10192
end

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ using Test,PianoHands,MIDI,Lux,Random,Printf,LuxCUDA,Optimisers,Zygote,JLD2
33
@testset "pig to feature" begin
44
# train_piano("../PianoFingeringDataset/dataset/",
55
# "../PianoFingeringDataset/testset/";
6-
# SEQ_LENGTH=70,
7-
# BATCH_SIZE=10,
6+
# SEQ_LENGTH=75,
7+
# BATCH_SIZE=12,
88
# LEARNING_RATE = 0.0005f0,
99
# HIDDEN_SIZE = 14,
10-
# EVALUATE_PER_N_TRAIN = 100
10+
# EVALUATE_PER_N_TRAIN = 50
1111
# )
1212

13-
# generate_midi("./ymsn_full.mid";weight_file="./14trained_model-0.92757.jld2",HIDDEN_SIZE=14)
13+
generate_midi("./ymsn_full.mid";weight_file="../model/model-0.91502.jld2")
1414
end

weight/weight-0.92757.jld2

-19.8 KB
Binary file not shown.

0 commit comments

Comments
 (0)