@@ -50,24 +50,70 @@ function HybridModel(nn::Lux.Chain, func::PartitionedFunction)
5050 return HybridModel (nn, func, nothing , nothing )
5151end
5252# TODO : This needs to be more general. i.e. ŷ = NN(α * NN(x) + β).
53+ #
54+ function (m:: HybridModel )(X:: VecOrMat{Float32} , params, st; forcings = nothing , return_parameters:: Val{T} = Val (false )) where {T}
55+ if T
56+ return runHybridModelAll (m, X, params, st; forcings = forcings, return_parameters = return_parameters)
57+ else
58+ return runHybridModelSimple (m, X, params, st; forcings = forcings)
59+ end
60+ end
5361
54- function (m:: HybridModel )( X:: Matrix{Float32} , params, st)
62+ function runHybridModelSimple (m:: HybridModel , X:: Matrix{Float32} , params, st; forcings )
5563 ps = params. nn
5664 globals = params. globals
5765 n_varargs = length (m. func. varying_args)
5866 out_NN = m. nn (X, ps, st)[1 ]
59- out = m. func. opt_func (tuple ([out_NN[i,:] for i = 1 : n_varargs]. .. ), globals)
67+ out = m. func. opt_func (tuple ([out_NN[i,:] for i = 1 : n_varargs]. .. ), globals; forcings = forcings )
6068 return out
6169end
62- function (m:: HybridModel )( X:: Vector{Float32} , params, st)
70+ function runHybridModelSimple (m:: HybridModel , X:: Vector{Float32} , params, st; forcings )
6371 ps = params. nn
6472 globals = params. globals
6573 n_varargs = length (m. func. varying_args)
6674 out_NN = m. nn (X, ps, st)[1 ]
67- out = m. func. opt_func (tuple ([[out_NN[1 ]] for i = 1 : n_varargs]. .. ), globals)
75+ out = m. func. opt_func (tuple ([[out_NN[1 ]] for i = 1 : n_varargs]. .. ), globals; forcings = forcings )
6876 return out[1 ]
6977end
7078
79+ function runHybridModelAll (m:: HybridModel , X:: Vector{Float32} , params, st; return_parameters:: Val{true} , forcings)
80+ ps = params. nn
81+ globals = params. globals
82+ n_varargs = length (m. func. varying_args)
83+ out_NN = m. nn (X, ps, st)[1 ]
84+ y = m. func. opt_func (tuple ([out_NN[i,:] for i = 1 : n_varargs]. .. ), globals; forcings = forcings)
85+ D = Dict {Symbol, Float32} ()
86+ D[:out ] = y[1 ]
87+ for (i, param) in enumerate (m. func. varying_args)
88+ D[Symbol (param)] = out_NN[i,1 ]
89+ end
90+ for (i, param) in enumerate (m. func. global_args)
91+ D[Symbol (param)] = globals[i]
92+ end
93+ for (i, param) in enumerate (m. func. fixed_args)
94+ D[Symbol (param)] = m. func. fixed_vals[i]
95+ end
96+ return D
97+ end
98+ function runHybridModelAll (m:: HybridModel , X:: Matrix{Float32} , params, st; return_parameters:: Val{true} , forcings)
99+ ps = params. nn
100+ globals = params. globals
101+ n_varargs = length (m. func. varying_args)
102+ out_NN = m. nn (X, ps, st)[1 ]
103+ y = m. func. opt_func (tuple ([[out_NN[1 ]] for i = 1 : n_varargs]. .. ), globals; forcings = forcings)
104+ D = Dict {Symbol, Vector{Float32}} ()
105+ D[:out ] = y[1 ]
106+ for (i, param) in enumerate (m. func. varying_args)
107+ D[Symbol (param)] = out_NN[i,:]
108+ end
109+ for (i, param) in enumerate (m. func. global_args)
110+ D[Symbol (param)] = ones (Float32, size (X,1 )) .* globals[i]
111+ end
112+ for (i, param) in enumerate (m. func. fixed_args)
113+ D[Symbol (param)] = ones (Float32, size (X,1 )) .* m. func. fixed_vals[i]
114+ end
115+ return D
116+ end
71117# Assumes that the last layer has sigmoid activation function
72118function setbounds (m:: HybridModel , bounds:: Dict{Symbol, Tuple{T,T}} ) where {T}
73119 n_args = length (m. func. varying_args)
0 commit comments