diff --git a/src/hybridModel.jl b/src/hybridModel.jl index 15e4f9c..ecaf0e2 100644 --- a/src/hybridModel.jl +++ b/src/hybridModel.jl @@ -50,24 +50,70 @@ function HybridModel(nn::Lux.Chain, func::PartitionedFunction) return HybridModel(nn, func, nothing, nothing) end # TODO: This needs to be more general. i.e. ŷ = NN(α * NN(x) + β). +# +function (m::HybridModel)(X::VecOrMat{Float32}, params, st; forcings = nothing, return_parameters::Val{T} = Val(false)) where {T} + if T + return runHybridModelAll(m, X, params, st; forcings = forcings, return_parameters = return_parameters) + else + return runHybridModelSimple(m, X, params, st; forcings = forcings) + end +end -function (m::HybridModel)(X::Matrix{Float32}, params, st) +function runHybridModelSimple(m::HybridModel, X::Matrix{Float32}, params, st; forcings) ps = params.nn globals = params.globals n_varargs = length(m.func.varying_args) out_NN = m.nn(X, ps, st)[1] - out = m.func.opt_func(tuple([out_NN[i,:] for i = 1:n_varargs]...), globals) + out = m.func.opt_func(tuple([out_NN[i,:] for i = 1:n_varargs]...), globals; forcings = forcings) return out end -function (m::HybridModel)(X::Vector{Float32}, params, st) +function runHybridModelSimple(m::HybridModel, X::Vector{Float32}, params, st; forcings) ps = params.nn globals = params.globals n_varargs = length(m.func.varying_args) out_NN = m.nn(X, ps, st)[1] - out = m.func.opt_func(tuple([[out_NN[1]] for i = 1:n_varargs]...), globals) + out = m.func.opt_func(tuple([[out_NN[1]] for i = 1:n_varargs]...), globals; forcings = forcings) return out[1] end +function runHybridModelAll(m::HybridModel, X::Vector{Float32}, params, st; return_parameters::Val{true}, forcings) + ps = params.nn + globals = params.globals + n_varargs = length(m.func.varying_args) + out_NN = m.nn(X, ps, st)[1] + y = m.func.opt_func(tuple([out_NN[i,:] for i = 1:n_varargs]...), globals; forcings = forcings) + D = Dict{Symbol, Float32}() + D[:out] = y[1] + for (i, param) in enumerate(m.func.varying_args) + D[Symbol(param)] = out_NN[i,1] + end + for (i, param) in enumerate(m.func.global_args) + D[Symbol(param)] = globals[i] + end + for (i, param) in enumerate(m.func.fixed_args) + D[Symbol(param)] = m.func.fixed_vals[i] + end + return D +end +function runHybridModelAll(m::HybridModel, X::Matrix{Float32}, params, st; return_parameters::Val{true}, forcings) + ps = params.nn + globals = params.globals + n_varargs = length(m.func.varying_args) + out_NN = m.nn(X, ps, st)[1] + y = m.func.opt_func(tuple([[out_NN[1]] for i = 1:n_varargs]...), globals; forcings = forcings) + D = Dict{Symbol, Vector{Float32}}() + D[:out] = y[1] + for (i, param) in enumerate(m.func.varying_args) + D[Symbol(param)] = out_NN[i,:] + end + for (i, param) in enumerate(m.func.global_args) + D[Symbol(param)] = ones(Float32, size(X,1)) .* globals[i] + end + for (i, param) in enumerate(m.func.fixed_args) + D[Symbol(param)] = ones(Float32, size(X,1)) .* m.func.fixed_vals[i] + end + return D +end # Assumes that the last layer has sigmoid activation function function setbounds(m::HybridModel, bounds::Dict{Symbol, Tuple{T,T}}) where {T} n_args = length(m.func.varying_args) diff --git a/src/macroHybrid.jl b/src/macroHybrid.jl index 2b1613f..69f445a 100644 --- a/src/macroHybrid.jl +++ b/src/macroHybrid.jl @@ -48,6 +48,7 @@ function optimize_func(fun, global_args, fixed_args, varying_args, fixed_vals) func_expr = copy(fun) func_body = func_expr.args[2] func_args = Expr(:tuple) + push!(func_args.args, :($(Expr(:parameters, :forcings)))) varying_tuple = Expr(:(::), :varying_params, Expr(:curly, :Tuple, [:(Vector{Float32}) for _ in varying_args]...)) push!(func_args.args, varying_tuple) global_array = Expr(:(::), :global_params, :(Vector{Float32})) @@ -83,4 +84,4 @@ macro hybrid(fun) return quote PartitionedFunction($(esc(fun)), $(esc(args_syms)), $(esc(global_args)), $(esc(fixed_args)), $(esc(varying_args)), $(fixed_vals), $(esc(opt_func))) end -end \ No newline at end of file +end diff --git a/test/core.jl b/test/core.jl index 9dfe736..0854246 100644 --- a/test/core.jl +++ b/test/core.jl @@ -19,7 +19,7 @@ function test_structuredfunc() @test length(structured.fixed_vals) == 1 @test structured.fixed_vals[1] == 1.0f0 opt_func = structured.opt_func - @test opt_func(([1.0f0], [1.0f0]), [1.0f0])[1] == exp(1.0f0) - 1.0f0 + @test opt_func(([1.0f0], [1.0f0]), [1.0f0]; forcings = nothing)[1] == exp(1.0f0) - 1.0f0 end function test_hybridmodel() @@ -45,6 +45,15 @@ function test_hybridmodel() model_params = (nn = ps, globals = globals) @test model(rand(Float32, 5), model_params, st) isa Float32 @test model(rand(Float32, 5,5), model_params, st) isa Vector{Float32} + @testset "Total Output" begin + @test model(rand(Float32, 5), model_params, st; return_parameters = Val(true)) isa Dict + output = model(rand(Float32, 5), model_params, st; return_parameters = Val(true)) + @testset "Arguments in output" for arg in [:α, :β, :γ, :δ] + @test arg in keys(output) + end + @test :out in keys(output) + @test output[:γ] == 1.0f0 + end end function test_bounds() @@ -112,3 +121,28 @@ function test_gradcalc() @test new_params != model_params end +function test_forcings() + local input_size = 5 + local α, β, γ, δ + @syms α::Real β::Real γ::Real δ::Real + structured = @hybrid function testfunc(α::Varying, β::Varying, γ::Fixed=1.0, δ::Global; forcings) + T = forcings[1,:] + return (exp.(α) .- β)./(γ .* δ) .+ T + end + NN = Chain( + Dense(input_size => 4, sigmoid_fast), + Dense(4 => 2, sigmoid_fast) + ) + NN = f32(NN) + rng = MersenneTwister() + model = HybridModel( + NN, + structured + ) + ps, st = setup(rng, model) + globals = [1.2f0] + model_params = (nn = ps, globals = globals) + @test model(rand(Float32, 5), model_params, st; forcings = ones(Float32, 1,1)) isa Float32 + @test model(rand(Float32, 5,5), model_params, st; forcings = ones(Float32, 1,5)) isa Vector{Float32} + @test all(model(rand(Float32, 5,5), model_params, st; forcings = ones(Float32, 1,5)) .> 1.0f0) +end diff --git a/test/sample.jl b/test/sample.jl index 61e6f75..4f8299a 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -4,8 +4,9 @@ const input_size = 5 local α, β, γ, δ @syms α::Real β::Real γ::Real δ::Real -structured = @hybrid function testfunc(α::Varying, β::Varying, γ::Fixed=1.0, δ::Global) - return (exp.(α) .- β)./(γ .* δ) +structured = @hybrid function testfunc(α::Varying, β::Varying, γ::Fixed=1.0, δ::Global; forcings) + T = forcings[1,:] + return (exp.(α) .- β)./(γ .* δ) .+ T end # TODO, define first symbolic function and then apply macro?