Skip to content

Commit e163157

Browse files
committed
allow sty.n_workers == 0
1 parent f09f384 commit e163157

File tree

4 files changed

+15
-17
lines changed

4 files changed

+15
-17
lines changed

src/studies/TGLF_database.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function _run(study::StudyTGLFdb)
6868
sty = study.sty
6969
act = study.act
7070

71-
@assert sty.n_workers == length(Distributed.workers()) "The number of workers = $(length(Distributed.workers())) isn't the number of workers you requested = $(sty.n_workers)"
71+
@assert (sty.n_workers == 0 || sty.n_workers == length(Distributed.workers())) "The number of workers = $(length(Distributed.workers())) isn't the number of workers you requested = $(sty.n_workers)"
7272
@assert ismissing(getproperty(sty, :sat_rules, missing)) ismissing(getproperty(sty, :custom_tglf_models, missing)) "Specify either sat_rules or custom_tglf_models"
7373

7474
cases_files = [

src/studies/database_generator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ Runs the DatabaseGenerator with sty settings in parallel on designated cluster
6969
function _run(study::StudyDatabaseGenerator)
7070
sty = study.sty
7171

72-
@assert sty.n_workers == length(Distributed.workers()) "The number of workers = $(length(Distributed.workers())) isn't the number of workers you requested = $(sty.n_workers)"
72+
@assert (sty.n_workers == 0 || sty.n_workers == length(Distributed.workers())) "The number of workers = $(length(Distributed.workers())) isn't the number of workers you requested = $(sty.n_workers)"
7373

7474
if typeof(study.ini) <: ParametersAllInits && typeof(study.act) <: ParametersAllActors
7575
iterator = collect(1:sty.n_simulations)

src/studies/experiment_postdictive.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@ Base.@kwdef mutable struct FUSEparameters__ParametersStudyPostdictive{T<:Real} <
1818
n_workers::Entry{Int} = study_common_parameters(; n_workers=missing)
1919
release_workers_after_run::Entry{Bool} = study_common_parameters(; release_workers_after_run=true)
2020
save_folder::Entry{String} = Entry{String}("-", "Folder to save the postdictive runs into")
21+
kw_case_parameters::Entry{Dict{Symbol,Any}} = Entry{Dict{Symbol,Any}}("-", "Keyword arguments passed to case_parameters"; default=Dict{Symbol,Any}())
2122

2223
# Postdictive-specific parameters
2324
device::Entry{Symbol} = Entry{Symbol}("-", "Device to run postdictive simulations for")
2425
shots::Entry{Vector{Int}} = Entry{Vector{Int}}("-", "List of shot numbers")
25-
fit_profiles::Entry{Bool} = Entry{Bool}("-", "Whether to fit profiles in case_parameters"; default=true)
2626
reconstruction::Entry{Bool} = Entry{Bool}("-", "Run postdiction in reconstruction mode")
27-
use_local_cache::Entry{Bool} = Entry{Bool}("-", "Whether to use local cache in case_parameters"; default=false)
2827
end
2928

3029
mutable struct StudyPostdictive{T<:Real} <: AbstractStudy
@@ -47,12 +46,12 @@ Runs the Postdictive study with sty settings in parallel on designated cluster
4746
function _run(study::StudyPostdictive)
4847
sty = study.sty
4948

50-
@assert sty.n_workers == length(Distributed.workers()) "The number of workers = $(length(Distributed.workers())) isn't the number of workers you requested = $(sty.n_workers)"
49+
@assert (sty.n_workers == 0 || sty.n_workers == length(Distributed.workers())) "The number of workers = $(length(Distributed.workers())) isn't the number of workers you requested = $(sty.n_workers)"
5150

5251
# parallel run
5352
println("running $(length(sty.shots)) postdictive simulations with $(sty.n_workers) workers on $(sty.server)")
5453

55-
ProgressMeter.@showprogress map(shot -> run_postdictive_case(study, shot), sty.shots)
54+
ProgressMeter.@showprogress map(shot -> run_postdictive_case(study, shot; sty.kw_case_parameters), sty.shots)
5655

5756
# Release workers after run
5857
if sty.release_workers_after_run
@@ -64,11 +63,11 @@ function _run(study::StudyPostdictive)
6463
end
6564

6665
"""
67-
run_postdictive_case(study::StudyPostdictive, shot::Int)
66+
run_postdictive_case(study::StudyPostdictive, shot::Int; kw_case_parameters::Dict{Symbol,Any})
6867
6968
Run a single postdictive case for a given device and shot
7069
"""
71-
function run_postdictive_case(study::StudyPostdictive, shot::Int)
70+
function run_postdictive_case(study::StudyPostdictive, shot::Int; kw_case_parameters::Dict{Symbol,Any})
7271
sty = study.sty
7372
device = sty.device
7473

@@ -91,7 +90,7 @@ function run_postdictive_case(study::StudyPostdictive, shot::Int)
9190
redirect_stderr(file_log)
9291
cd(savedir)
9392

94-
run_postdictive_case(device, shot; user_act=study.act, sty.fit_profiles, sty.use_local_cache, savedir, sty.reconstruction)
93+
run_postdictive_case(device, shot; user_act=study.act, savedir, sty.reconstruction, kw_case_parameters)
9594

9695
# catch e
9796
# if isa(e, InterruptException)
@@ -106,10 +105,10 @@ function run_postdictive_case(study::StudyPostdictive, shot::Int)
106105
end
107106
end
108107

109-
function run_postdictive_case(device::Symbol, shot::Int; kw...)
108+
function run_postdictive_case(device::Symbol, shot::Int; kw_case_parameters::Dict{Symbol,Any}, kw...)
110109
dd = IMAS.dd()
111110
dd_exp = IMAS.dd()
112-
run_postdictive_case!(dd, dd_exp, device, shot; kw...)
111+
run_postdictive_case!(dd, dd_exp, device, shot; kw_case_parameters, kw...)
113112
return (dd=dd, dd_exp=dd_exp)
114113
end
115114

@@ -119,15 +118,14 @@ function run_postdictive_case!(
119118
device::Symbol,
120119
shot::Int;
121120
user_act::ParametersActors,
122-
fit_profiles::Bool,
123-
use_local_cache::Bool,
124121
savedir::AbstractString=abspath("."),
125-
reconstruction::Bool
122+
reconstruction::Bool,
123+
kw_case_parameters::Dict{Symbol,Any}
126124
)
127125

128126
# Get case parameters
129-
@info "case_parameters($(repr(device)), $shot; fit_profiles=$fit_profiles, use_local_cache=$use_local_cache)"
130-
ini, act = FUSE.case_parameters(device, shot; fit_profiles, use_local_cache)
127+
@info "case_parameters($(repr(device)), $shot; $(repr(kw_case_parameters))...)"
128+
ini, act = FUSE.case_parameters(device, shot; kw_case_parameters...)
131129

132130
# Override act with user-specific actor parameters
133131
#merge!(act, user_act)

src/studies/multi_objective_optimization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ end
6262
function _run(study::StudyMultiObjectiveOptimizer)
6363
sty = study.sty
6464

65-
@assert sty.n_workers == length(Distributed.workers()) "The number of workers = $(length(Distributed.workers())) isn't the number of workers you requested = $(sty.n_workers)"
65+
@assert (sty.n_workers == 0 || sty.n_workers == length(Distributed.workers())) "The number of workers = $(length(Distributed.workers())) isn't the number of workers you requested = $(sty.n_workers)"
6666
@assert iseven(sty.population_size) "Population size must be even"
6767

6868
if sty.restart_workers_after_n_generations > 0

0 commit comments

Comments
 (0)