diff --git a/Project.toml b/Project.toml index d36cc6ea..1ea7ad8e 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab" NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" @@ -31,6 +32,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StyledStrings = "f489334b-da3d-4c2e-b8f0-e476e12c162b" +YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] @@ -60,12 +62,14 @@ Makie = "0.22, 0.23, 0.24" NCDatasets = "0.14.8" NamedDims = "1.2.3" OptimizationOptimisers = "0.3.7" +OrderedCollections = "1.8.1" PrettyTables = "2.4.0, 3.1.2" ProgressMeter = "1.10.4" Reexport = "1.2.2" Static = "1.3.1" Statistics = "1" StyledStrings = "1.0.3, 1.11.0" +YAML = "0.4.16" Zygote = "0.7.7" julia = "1.10" diff --git a/src/EasyHybrid.jl b/src/EasyHybrid.jl index 985c6439..3cd83a53 100644 --- a/src/EasyHybrid.jl +++ b/src/EasyHybrid.jl @@ -26,6 +26,7 @@ using MLJ: partition using MLUtils: MLUtils, DataLoader, kfolds, numobs, rpad, splitobs using NCDatasets: NCDatasets, NCDataset, close, name using OptimizationOptimisers: OptimizationOptimisers, AdamW, Adam, Optimisers +using OrderedCollections: OrderedDict using PrettyTables: PrettyTables using Printf: Printf, @sprintf using ProgressMeter: ProgressMeter, Progress, next! @@ -33,6 +34,7 @@ using Random: Random, AbstractRNG, randperm, randstring using Reexport: @reexport using Statistics: Statistics, mean, cor, quantile, var using StyledStrings: StyledStrings, @styled_str +using YAML: load_file, write_file using Zygote: Zygote using Static: False, True @@ -66,5 +68,6 @@ include("utils/helpers_for_HybridModel.jl") include("utils/helpers_data_loading.jl") include("tune.jl") include("utils/helpers_cross_validation.jl") +include("utils/config_yaml.jl") end diff --git a/src/train.jl b/src/train.jl index 2f3b2631..2c38cf91 100644 --- a/src/train.jl +++ b/src/train.jl @@ -354,6 +354,14 @@ function train( train_diffs = !isempty(set_diff) ? NamedTuple{Tuple(set_diff)}([getproperty(ŷ_train, e) for e in set_diff]) : nothing val_diffs = !isempty(set_diff) ? NamedTuple{Tuple(set_diff)}([getproperty(ŷ_val, e) for e in set_diff]) : nothing + # collect all training arguments into a NamedTuple for easier saving and logging + train_args = (; nepochs, batchsize, opt, patience, autodiff_backend, return_gradients, array_type, training_loss, loss_types, extra_loss, agg, train_from, random_seed, file_name, hybrid_name, return_model, monitor_names, folder_to_save, plotting, show_progress, yscale) + + # get config settings and save to yaml file + get_config_settings = get_full_config(hybridModel, train_args) + path_yaml = joinpath(tmp_folder, "config_settings.yaml") + save_hybrid_config(get_config_settings, path_yaml) + # TODO: save/output metrics return TrainResults( train_history, @@ -366,7 +374,7 @@ function train( ps, st, best_epoch, - best_agg_loss + best_agg_loss, ) end diff --git a/src/utils/config_yaml.jl b/src/utils/config_yaml.jl new file mode 100644 index 00000000..c64a6a58 --- /dev/null +++ b/src/utils/config_yaml.jl @@ -0,0 +1,33 @@ +export load_hybrid_config, save_hybrid_config +export get_hybrid_config, get_train_config, get_full_config + +function load_hybrid_config(path::String; dicttype = OrderedDict{String, Any}) + return load_file(path; dicttype) +end + +function save_hybrid_config(config::OrderedDict, path::String) + return write_file(path, config) +end + +function get_hybrid_config(hm::LuxCore.AbstractLuxContainerLayer) + hm_config = OrderedDict{String, Any}() + for field in fieldnames(typeof(hm)) + hm_config[string(field)] = getfield(hm, field) + end + return hm_config +end + +function get_train_config(train_args::NamedTuple) + train_config = OrderedDict{String, Any}() + for field in fieldnames(typeof(train_args)) + train_config[string(field)] = getfield(train_args, field) + end + return train_config +end + +function get_full_config(hm::LuxCore.AbstractLuxContainerLayer, train_args::NamedTuple) + full_config = OrderedDict{String, Any}() + full_config["hybrid_model"] = get_hybrid_config(hm) + full_config["train_args"] = get_train_config(train_args) + return full_config +end