1- export Global, Varying, Fixed, PartitionedFunction, HybridModel
1+ export Global, Varying, Fixed, PartitionedFunction, HybridModel, setbounds, setup
22export HybridSymbolic, SymbolTypes
33
44abstract type SymbolTypes end
@@ -39,9 +39,15 @@ struct PartitionedFunction{F,O,A1,A2,A3,A4,V} <: HybridSymbolic
3939 end
4040end
4141
42- struct HybridModel <: HybridSymbolic
42+ @proto struct HybridModel{T} <: HybridSymbolic
4343 nn:: Lux.Chain
4444 func:: PartitionedFunction
45+ p_min:: T
46+ p_max:: T
47+ end
48+
49+ function HybridModel (nn:: Lux.Chain , func:: PartitionedFunction )
50+ return HybridModel (nn, func, nothing , nothing )
4551end
4652# TODO : This needs to be more general. i.e. ŷ = NN(α * NN(x) + β).
4753
@@ -60,4 +66,24 @@ function (m::HybridModel)(X::Vector{Float32}, params, st)
6066 out_NN = m. nn (X, ps, st)[1 ]
6167 out = m. func. opt_func (tuple ([[out_NN[1 ]] for i = 1 : n_varargs]. .. ), globals)
6268 return out[1 ]
63- end
69+ end
70+
71+ # Assumes that the last layer has sigmoid activation function
72+ function setbounds (m:: HybridModel , bounds:: Dict{Symbol, Tuple{T,T}} ) where {T}
73+ n_args = length (m. func. varying_args)
74+ p_min = zeros (Float32, n_args)
75+ p_max = zeros (Float32, n_args)
76+ for (i,arg) in enumerate (Symbol .(m. func. varying_args))
77+ @assert arg in keys (bounds)
78+ p_min[i] = bounds[arg][1 ]
79+ p_max[i] = bounds[arg][2 ]
80+ end
81+ p_range = p_max .- p_min
82+ wf = WrappedFunction ((x) -> x .* (p_range) .+ p_min)
83+ new_nn = Chain (m. nn, wf)
84+ return HybridModel (new_nn, m. func, p_min, p_max)
85+ end
86+
87+ function setup (rng:: AbstractRNG , m:: HybridModel )
88+ return Lux. setup (rng, m. nn)
89+ end
0 commit comments