@@ -11,13 +11,12 @@ include("variable_indexed_structs.jl")
1111
1212Get the ordered output variables from the input-output map.
1313"""
14- function get_ordered_output_variables (input_output_map:: Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}} )
14+ function get_ordered_output_variables (
15+ input_output_map:: Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}} ,
16+ )
1517 return reduce (
16- vcat,
17- [
18- reduce (vcat, values (iomap))
19- for iomap in input_output_map
20- ]
18+ vcat,
19+ [reduce (vcat, values (iomap)) for iomap in input_output_map],
2120 )
2221end
2322
2625
2726Get the input indices from the input-output map.
2827"""
29- function get_input_indices (input_output_map:: Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}} )
28+ function get_input_indices (
29+ input_output_map:: Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}} ,
30+ )
3031 return unique (
31- reduce (
32- vcat,
33- [
34- reduce (vcat, keys (iomap))
35- for iomap in input_output_map
36- ]
37- )
32+ reduce (vcat, [reduce (vcat, keys (iomap)) for iomap in input_output_map]),
3833 )
3934end
4035
4338
4439Get the maximum input index from the input-output maps.
4540"""
46- function get_max_input_index (input_output_map:: Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}} )
41+ function get_max_input_index (
42+ input_output_map:: Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}} ,
43+ )
4744 return maximum (get_input_indices (input_output_map))
4845end
4946
@@ -82,14 +79,20 @@ julia> pred_model = PredictiveModel(
8279"""
8380struct PredictiveModel
8481 networks:: Union{Vector{<:Flux.Chain},Vector{<:Flux.Dense}}
85- input_output_map:: Union {Vector{<: Dict{Vector{Int},<:Vector{<:Forecast}} },Nothing}
82+ input_output_map:: Union {
83+ Vector{<: Dict{Vector{Int},<:Vector{<:Forecast}} },
84+ Nothing,
85+ }
8686 output_variables:: Union{Vector{<:Forecast},Nothing}
8787 input_size:: Int
8888 output_size:: Int
8989
9090 function PredictiveModel (
9191 networks:: Union{Vector{<:Flux.Chain},Vector{<:Flux.Dense}} ,
92- input_output_map:: Union {Vector{<: Dict{Vector{Int},<:Vector{<:Forecast}} },Nothing},
92+ input_output_map:: Union {
93+ Vector{<: Dict{Vector{Int},<:Vector{<:Forecast}} },
94+ Nothing,
95+ },
9396 output_variables:: Union{Vector{<:Forecast},Nothing} ,
9497 input_size:: Int ,
9598 output_size:: Int ,
@@ -99,7 +102,7 @@ struct PredictiveModel
99102 input_output_map,
100103 output_variables,
101104 input_size,
102- output_size
105+ output_size,
103106 )
104107 end
105108end
@@ -112,7 +115,10 @@ from Flux models and input/output map.
112115"""
113116function PredictiveModel (
114117 networks:: Union{Vector{<:Flux.Chain},Vector{<:Flux.Dense}} ,
115- input_output_map:: Union {Vector{<: Dict{Vector{Int},<:Vector{<:Forecast}} },Nothing},
118+ input_output_map:: Union {
119+ Vector{<: Dict{Vector{Int},<:Vector{<:Forecast}} },
120+ Nothing,
121+ },
116122)
117123 output_variables = get_ordered_output_variables (input_output_map)
118124 input_size = get_max_input_index (input_output_map)
@@ -122,7 +128,7 @@ function PredictiveModel(
122128 input_output_map,
123129 output_variables,
124130 input_size,
125- output_size
131+ output_size,
126132 )
127133end
128134
@@ -231,7 +237,7 @@ specifying that only the networks field is trainable.
231237Flux. trainable (model:: PredictiveModel ) = (networks = model. networks,)
232238
233239# Tells Flux to only look at the 'network' field when setting up or traversing
234- @ Functors. functor PredictiveModel (networks,)
240+ Functors. @ functor PredictiveModel (networks,)
235241
236242"""
237243 (model::PredictiveModel)(X::AbstractMatrix, ignore_index::Bool = false)
280286"""
281287 (model::PredictiveModel)(x::AbstractVector, ignore_index::Bool = false)
282288
283- Predict the output of the model for a given input vector.
289+ Predict the output of the model for a given input vector.
284290If the model has no input-output map, the network is applied directly to the input.
285291If ignore_index is true, the output variables are not returned.
286292"""
@@ -375,6 +381,10 @@ function apply_gradient!(
375381 X:: Matrix{<:Real} ,
376382 opt_state,
377383)
378- return apply_gradient! (model, dCdy[model. output_variables]. data, X, opt_state)
384+ return apply_gradient! (
385+ model,
386+ dCdy[model. output_variables]. data,
387+ X,
388+ opt_state,
389+ )
379390end
380-
0 commit comments