Skip to content

Commit 54cf4d1

Browse files
authored
setup yaml (#231)
* setup yaml * get settings * write and read work now
1 parent 8d83893 commit 54cf4d1

File tree

4 files changed

+49
-1
lines changed

4 files changed

+49
-1
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
2323
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
2424
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
2525
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
26+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2627
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
2728
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2829
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
@@ -31,6 +32,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
3132
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
3233
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3334
StyledStrings = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
35+
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
3436
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3537

3638
[weakdeps]
@@ -60,12 +62,14 @@ Makie = "0.22, 0.23, 0.24"
6062
NCDatasets = "0.14.8"
6163
NamedDims = "1.2.3"
6264
OptimizationOptimisers = "0.3.7"
65+
OrderedCollections = "1.8.1"
6366
PrettyTables = "2.4.0, 3.1.2"
6467
ProgressMeter = "1.10.4"
6568
Reexport = "1.2.2"
6669
Static = "1.3.1"
6770
Statistics = "1"
6871
StyledStrings = "1.0.3, 1.11.0"
72+
YAML = "0.4.16"
6973
Zygote = "0.7.7"
7074
julia = "1.10"
7175

src/EasyHybrid.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@ using MLJ: partition
2626
using MLUtils: MLUtils, DataLoader, kfolds, numobs, rpad, splitobs
2727
using NCDatasets: NCDatasets, NCDataset, close, name
2828
using OptimizationOptimisers: OptimizationOptimisers, AdamW, Adam, Optimisers
29+
using OrderedCollections: OrderedDict
2930
using PrettyTables: PrettyTables
3031
using Printf: Printf, @sprintf
3132
using ProgressMeter: ProgressMeter, Progress, next!
3233
using Random: Random, AbstractRNG, randperm, randstring
3334
using Reexport: @reexport
3435
using Statistics: Statistics, mean, cor, quantile, var
3536
using StyledStrings: StyledStrings, @styled_str
37+
using YAML: load_file, write_file
3638
using Zygote: Zygote
3739
using Static: False, True
3840

@@ -66,5 +68,6 @@ include("utils/helpers_for_HybridModel.jl")
6668
include("utils/helpers_data_loading.jl")
6769
include("tune.jl")
6870
include("utils/helpers_cross_validation.jl")
71+
include("utils/config_yaml.jl")
6972

7073
end

src/train.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,14 @@ function train(
354354
train_diffs = !isempty(set_diff) ? NamedTuple{Tuple(set_diff)}([getproperty(ŷ_train, e) for e in set_diff]) : nothing
355355
val_diffs = !isempty(set_diff) ? NamedTuple{Tuple(set_diff)}([getproperty(ŷ_val, e) for e in set_diff]) : nothing
356356

357+
# collect all training arguments into a NamedTuple for easier saving and logging
358+
train_args = (; nepochs, batchsize, opt, patience, autodiff_backend, return_gradients, array_type, training_loss, loss_types, extra_loss, agg, train_from, random_seed, file_name, hybrid_name, return_model, monitor_names, folder_to_save, plotting, show_progress, yscale)
359+
360+
# get config settings and save to yaml file
361+
get_config_settings = get_full_config(hybridModel, train_args)
362+
path_yaml = joinpath(tmp_folder, "config_settings.yaml")
363+
save_hybrid_config(get_config_settings, path_yaml)
364+
357365
# TODO: save/output metrics
358366
return TrainResults(
359367
train_history,
@@ -366,7 +374,7 @@ function train(
366374
ps,
367375
st,
368376
best_epoch,
369-
best_agg_loss
377+
best_agg_loss,
370378
)
371379
end
372380

src/utils/config_yaml.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
export load_hybrid_config, save_hybrid_config
2+
export get_hybrid_config, get_train_config, get_full_config
3+
4+
function load_hybrid_config(path::String; dicttype = OrderedDict{String, Any})
5+
return load_file(path; dicttype)
6+
end
7+
8+
function save_hybrid_config(config::OrderedDict, path::String)
9+
return write_file(path, config)
10+
end
11+
12+
function get_hybrid_config(hm::LuxCore.AbstractLuxContainerLayer)
13+
hm_config = OrderedDict{String, Any}()
14+
for field in fieldnames(typeof(hm))
15+
hm_config[string(field)] = getfield(hm, field)
16+
end
17+
return hm_config
18+
end
19+
20+
function get_train_config(train_args::NamedTuple)
21+
train_config = OrderedDict{String, Any}()
22+
for field in fieldnames(typeof(train_args))
23+
train_config[string(field)] = getfield(train_args, field)
24+
end
25+
return train_config
26+
end
27+
28+
function get_full_config(hm::LuxCore.AbstractLuxContainerLayer, train_args::NamedTuple)
29+
full_config = OrderedDict{String, Any}()
30+
full_config["hybrid_model"] = get_hybrid_config(hm)
31+
full_config["train_args"] = get_train_config(train_args)
32+
return full_config
33+
end

0 commit comments

Comments
 (0)