Skip to content

Commit 43caf12

Browse files
committed
Updated MRNA degradation example
1 parent 13caa39 commit 43caf12

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

mrnad/gen.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" Generate training, validation and test set data for the model of mRNA turnover """
1+
# Generate training, validation and test set data for the model of mRNA turnover
22
using Sobol
33
using JLD2
44

@@ -46,7 +46,7 @@ end
4646

4747
# Convert reaction network into a JumpProblem for use with the SSA
4848
jsys = convert(JumpSystem, rn, combinatoric_ratelaw=false)
49-
dprob = DiscreteProblem(jsys, u0, (0.0, last(ts)), zeros(Float64, numreactionparams(rn)))
49+
dprob = DiscreteProblem(jsys, u0, (0.0, last(ts)), zeros(Float64, 18))
5050
jprob = JumpProblem(jsys, dprob, Direct(), save_positions=(false, false))
5151

5252
# Full-length mRNA (A + B + BC1 + ... + BC5 + C + D + E + F)
@@ -61,11 +61,11 @@ seq = LogSobolSeq(ranges[:,1], ranges[:,2])
6161

6262
@time train_pts = [ Sobol.next!(seq) for i in 1:100000 ]
6363
@time valid_pts = [ Sobol.next!(seq) for i in 1:100 ]
64-
@time test_pts = [ Sobol.next!(seq) for i in 1:5000 ]
64+
@time test_pts = [ Sobol.next!(seq) for i in 1:1000 ]
6565

66+
X_test, y_test = build_dataset(ts, test_pts, solver_accurate)
67+
@save joinpath(MODEL_DIR, "test_data.jld2") X_test y_test
6668
X_train, y_train = build_dataset(ts, train_pts, solver)
6769
@save joinpath(MODEL_DIR, "train_data.jld2") X_train y_train
6870
X_valid, y_valid = build_dataset(ts, valid_pts, solver_accurate)
6971
@save joinpath(MODEL_DIR, "valid_data.jld2") X_valid y_valid
70-
X_test, y_test = build_dataset(ts, test_pts, solver_accurate)
71-
@save joinpath(MODEL_DIR, "test_data.jld2") X_test y_test

mrnad/plot.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ u0[1] = 1
3535

3636
# Set up the SSA simulations
3737
jsys = convert(JumpSystem, rn, combinatoric_ratelaw=false)
38-
dprob = DiscreteProblem(jsys, u0, (0.0, last(ts)), zeros(Float64, numreactionparams(rn)))
38+
dprob = DiscreteProblem(jsys, u0, (0.0, last(ts)), zeros(Float64, 18))
3939
jprob = JumpProblem(jsys, dprob, Direct(), save_positions=(false, false))
4040

4141
# full-length mRNA (A + B + BC1 + ... + BC5 + C + D + E + F)
@@ -86,9 +86,9 @@ plt = plot(plt1, plt2, layout=l, size=(330, 120), bottom_margin=0Plots.mm, top_m
8686
# Predicted vs true moments
8787
# ---------------------------------------------------------------------------------------------------
8888

89-
# Plotting points only for t=1000
89+
# Plotting points only for t=500
9090
ind = 4
91-
m = 4
91+
m = 8
9292
m_NN = mean.(Distribution.(Ref(model), X_test[ind:m:end]))
9393
var_NN = var.(Distribution.(Ref(model), X_test[ind:m:end]))
9494

@@ -126,4 +126,4 @@ plt2 = annotate!(plt2, [(-1.5, -1.5, Plots.text("×10⁴", 6, :black, :center))]
126126
plot(plt1, plt2, size=(260, 130),
127127
left_margin=-1Plots.mm, bottom_margin=0Plots.mm, top_margin=-1Plots.mm, right_margin=0Plots.mm)
128128

129-
#savefig(joinpath(MODEL_DIR, "true_vs_predict_moments.svg"))
129+
savefig(joinpath(MODEL_DIR, "true_vs_predict_moments.svg"))

mrnad/train.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" Train Nessie using the data generated by mrnad/gen.jl """
1+
# Train Nessie using the data generated by mrnad/gen.jl
22
using JLD2
33

44
include("../train_NN.jl")

0 commit comments

Comments
 (0)