Merged
Changes from 1 commit
4 changes: 4 additions & 0 deletions Project.toml
@@ -23,6 +23,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"

Review comment (medium):

The OrderedCollections dependency has been added. Please ensure that this package is necessary for the YAML functionality and that its inclusion does not introduce any unnecessary overhead or conflicts with existing dependencies.

PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
@@ -31,6 +32,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StyledStrings = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

Review comment (medium):

The YAML dependency has been added. This is a core dependency for the new configuration features. Please confirm that the chosen version 0.4.16 is compatible with other packages and the Julia version.

Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
@@ -60,12 +62,14 @@ Makie = "0.22, 0.23, 0.24"
NCDatasets = "0.14.8"
NamedDims = "1.2.3"
OptimizationOptimisers = "0.3.7"
OrderedCollections = "1.8.1"

Review comment (medium):

The OrderedCollections compatibility entry has been added. Ensure that the specified version range 1.8.1 is accurate and covers the required functionality without introducing breaking changes.

PrettyTables = "2.4.0, 3.1.2"
ProgressMeter = "1.10.4"
Reexport = "1.2.2"
Static = "1.3.1"
Statistics = "1"
StyledStrings = "1.0.3, 1.11.0"
YAML = "0.4.16"

Review comment (medium):

The YAML compatibility entry has been added. Please verify that the version 0.4.16 is the correct and stable version to use for this project.
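
For context on what these compat entries permit: Julia's package manager reads a bare entry such as `YAML = "0.4.16"` as a caret range, not as a pinned version. The sketch below re-implements that rule by hand with Base's `VersionNumber`. `caret_upper` and `satisfies` are hypothetical helpers written for illustration (they are not Pkg API), and pre-release tags are ignored.

```julia
# Hypothetical helpers illustrating caret semantics for Project.toml [compat] entries.
# The exclusive upper bound bumps the left-most non-zero component of the lower bound.
function caret_upper(lower::VersionNumber)
    lower.major > 0 && return VersionNumber(lower.major + 1, 0, 0)
    lower.minor > 0 && return VersionNumber(0, lower.minor + 1, 0)
    return VersionNumber(0, 0, lower.patch + 1)
end

# True when `v` falls inside the caret range that starts at `lower`.
satisfies(v::VersionNumber, lower::VersionNumber) = lower <= v < caret_upper(lower)

yaml_lower = v"0.4.16"   # from `YAML = "0.4.16"`
oc_lower = v"1.8.1"      # from `OrderedCollections = "1.8.1"`

println(satisfies(v"0.4.17", yaml_lower))  # true
println(satisfies(v"0.5.0", yaml_lower))   # false
println(satisfies(v"1.9.0", oc_lower))     # true
println(satisfies(v"2.0.0", oc_lower))     # false
```

Under this rule the YAML entry admits 0.4.16 up to, but not including, 0.5.0, and the OrderedCollections entry admits 1.8.1 up to, but not including, 2.0.0.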

Zygote = "0.7.7"
julia = "1.10"

3 changes: 3 additions & 0 deletions src/EasyHybrid.jl
@@ -26,13 +26,15 @@ using MLJ: partition
using MLUtils: MLUtils, DataLoader, kfolds, numobs, rpad, splitobs
using NCDatasets: NCDatasets, NCDataset, close, name
using OptimizationOptimisers: OptimizationOptimisers, AdamW, Adam, Optimisers
using OrderedCollections: OrderedDict
using PrettyTables: PrettyTables
using Printf: Printf, @sprintf
using ProgressMeter: ProgressMeter, Progress, next!
using Random: Random, AbstractRNG, randperm, randstring
using Reexport: @reexport
using Statistics: Statistics, mean, cor, quantile, var
using StyledStrings: StyledStrings, @styled_str
using YAML: load_file, write_file
using Zygote: Zygote
using Static: False, True

@@ -66,5 +68,6 @@ include("utils/helpers_for_HybridModel.jl")
include("utils/helpers_data_loading.jl")
include("tune.jl")
include("utils/helpers_cross_validation.jl")
include("utils/config_yaml.jl")

end
4 changes: 3 additions & 1 deletion src/train.jl
@@ -12,6 +12,7 @@ struct TrainResults
st
best_epoch
best_loss
train_args

Review comment (medium):

Adding train_args to TrainResults is a good improvement for reproducibility, as it allows all training parameters to be stored with the results. This is a high-value addition for debugging and understanding past training runs.

end

"""
@@ -366,7 +367,8 @@ function train(
ps,
st,
best_epoch,
- best_agg_loss
+ best_agg_loss,
+ (; nepochs, batchsize, opt, patience, autodiff_backend, return_gradients, array_type, training_loss, loss_types, extra_loss, agg, train_from, random_seed, file_name, hybrid_name, return_model, monitor_names, folder_to_save, plotting, show_progress, yscale)

Review comment (medium):

The train_args tuple is being constructed with a comprehensive list of training parameters. This ensures that all relevant arguments are captured and stored in TrainResults, which is excellent for traceability and reproducibility. Consider if any other kwargs passed to train should also be explicitly captured here for completeness.
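
To make the construct concrete: `(; a, b)` is Julia's shorthand for a `NamedTuple` whose field names are taken from the variable names. The sketch below uses a hypothetical, trimmed-down `train_sketch` with only two named keywords. It also assumes extra keywords arrive through `kwargs...`, which the diff does not show for the real `train`.

```julia
# Hypothetical stand-in for `train`; the real function takes many more arguments.
function train_sketch(; nepochs = 200, batchsize = 64, kwargs...)
    # `(; a, b)` names each field after the variable it is built from;
    # splatting `kwargs...` appends any extra keywords the caller passed.
    train_args = (; nepochs, batchsize, kwargs...)
    return train_args
end

args = train_sketch(nepochs = 10, yscale = :log10)
println(keys(args))     # field names, in the order they were captured
println(args.nepochs)   # 10
```

Splatting `kwargs...` is one way to avoid maintaining the captured list by hand. Listing each name explicitly, as the diff does, makes the stored fields easier to audit.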

)
end

24 changes: 24 additions & 0 deletions src/utils/config_yaml.jl
@@ -0,0 +1,24 @@
export load_hybrid_config, save_hybrid_config
function load_hybrid_config(path::String; dicttype=OrderedDict{String,Any})
    return load_file(path; dicttype)
end

function save_hybrid_config(config::AbstractDict, path::String)  # AbstractDict also accepts the OrderedDict that load_hybrid_config returns
    return write_file(path, config)
end

function get_hybrid_config(hm::HybridModel)
    hm_config = Dict{String,Any}()
    for field in fieldnames(typeof(hm))
        hm_config[string(field)] = getfield(hm, field)
    end
    return hm_config
end

Review comment (medium):

The get_hybrid_config function dynamically extracts all fields from a HybridModel instance into a dictionary. This is a flexible approach for serializing model configurations. Ensure that all fields are suitable for direct serialization to YAML.
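
One way to act on the serialization concern is to convert each field value to a plain type before it goes into the dictionary. The sketch below is an assumption-laden illustration, not EasyHybrid code: `yaml_safe`, `struct_to_config`, and `ToyModel` are hypothetical names, and the fallback simply stringifies anything it does not recognize.

```julia
# Hypothetical helper: map a value to something a YAML writer can emit as plain data.
yaml_safe(x::Union{Nothing, Bool, Number, AbstractString}) = x
yaml_safe(x::Symbol) = String(x)
yaml_safe(x::Union{AbstractVector, Tuple}) = [yaml_safe(v) for v in x]
yaml_safe(x::NamedTuple) = Dict{String, Any}(String(k) => yaml_safe(v) for (k, v) in pairs(x))
yaml_safe(x::AbstractDict) = Dict{String, Any}(string(k) => yaml_safe(v) for (k, v) in x)
yaml_safe(x::Function) = string(nameof(x))
yaml_safe(x) = string(x)  # fallback: store a printable representation

# Hypothetical stand-in for a model struct with a few awkward field types.
struct ToyModel
    forcing::Vector{Symbol}
    hidden_layers::Tuple{Int, Int}
    activation::Function
    scale::Float64
end

# Same field loop as in the diff, with each value passed through `yaml_safe`.
function struct_to_config(obj)
    cfg = Dict{String, Any}()
    for field in fieldnames(typeof(obj))
        cfg[string(field)] = yaml_safe(getfield(obj, field))
    end
    return cfg
end

cfg = struct_to_config(ToyModel([:rad, :temp], (16, 16), tanh, 0.5))
println(cfg["activation"])  # tanh
```

The conversion is lossy: a function is stored by name only, so a loader would need its own lookup to restore it.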


function get_train_config(train_args::TrainResults)
    train_config = Dict{String,Any}()
    for field in fieldnames(typeof(train_args))
        train_config[string(field)] = getfield(train_args, field)
    end
    return train_config
end

Review comment (medium):

Similar to get_hybrid_config, get_train_config extracts all fields from TrainResults. This is crucial for saving the complete training state and arguments, enhancing reproducibility. Consider if any fields within TrainResults might contain non-serializable objects that would cause issues when writing to YAML.
