
Commit c079eb9

Merge pull request #3 from EarthyScience/forcings
forcings
2 parents b6ec807 + d04c10c commit c079eb9

File tree

4 files changed: +90 -8 lines
src/hybridModel.jl

Lines changed: 50 additions & 4 deletions
@@ -50,24 +50,70 @@ function HybridModel(nn::Lux.Chain, func::PartitionedFunction)
     return HybridModel(nn, func, nothing, nothing)
 end
 # TODO: This needs to be more general. i.e. ŷ = NN(α * NN(x) + β).
+#
+function (m::HybridModel)(X::VecOrMat{Float32}, params, st; forcings = nothing, return_parameters::Val{T} = Val(false)) where {T}
+    if T
+        return runHybridModelAll(m, X, params, st; forcings = forcings, return_parameters = return_parameters)
+    else
+        return runHybridModelSimple(m, X, params, st; forcings = forcings)
+    end
+end
 
-function (m::HybridModel)(X::Matrix{Float32}, params, st)
+function runHybridModelSimple(m::HybridModel, X::Matrix{Float32}, params, st; forcings)
     ps = params.nn
     globals = params.globals
     n_varargs = length(m.func.varying_args)
     out_NN = m.nn(X, ps, st)[1]
-    out = m.func.opt_func(tuple([out_NN[i,:] for i = 1:n_varargs]...), globals)
+    out = m.func.opt_func(tuple([out_NN[i,:] for i = 1:n_varargs]...), globals; forcings = forcings)
     return out
 end
-function (m::HybridModel)(X::Vector{Float32}, params, st)
+function runHybridModelSimple(m::HybridModel, X::Vector{Float32}, params, st; forcings)
     ps = params.nn
     globals = params.globals
     n_varargs = length(m.func.varying_args)
     out_NN = m.nn(X, ps, st)[1]
-    out = m.func.opt_func(tuple([[out_NN[1]] for i = 1:n_varargs]...), globals)
+    out = m.func.opt_func(tuple([[out_NN[1]] for i = 1:n_varargs]...), globals; forcings = forcings)
     return out[1]
 end
 
+function runHybridModelAll(m::HybridModel, X::Vector{Float32}, params, st; return_parameters::Val{true}, forcings)
+    ps = params.nn
+    globals = params.globals
+    n_varargs = length(m.func.varying_args)
+    out_NN = m.nn(X, ps, st)[1]
+    y = m.func.opt_func(tuple([out_NN[i,:] for i = 1:n_varargs]...), globals; forcings = forcings)
+    D = Dict{Symbol, Float32}()
+    D[:out] = y[1]
+    for (i, param) in enumerate(m.func.varying_args)
+        D[Symbol(param)] = out_NN[i,1]
+    end
+    for (i, param) in enumerate(m.func.global_args)
+        D[Symbol(param)] = globals[i]
+    end
+    for (i, param) in enumerate(m.func.fixed_args)
+        D[Symbol(param)] = m.func.fixed_vals[i]
+    end
+    return D
+end
+function runHybridModelAll(m::HybridModel, X::Matrix{Float32}, params, st; return_parameters::Val{true}, forcings)
+    ps = params.nn
+    globals = params.globals
+    n_varargs = length(m.func.varying_args)
+    out_NN = m.nn(X, ps, st)[1]
+    y = m.func.opt_func(tuple([[out_NN[1]] for i = 1:n_varargs]...), globals; forcings = forcings)
+    D = Dict{Symbol, Vector{Float32}}()
+    D[:out] = y[1]
+    for (i, param) in enumerate(m.func.varying_args)
+        D[Symbol(param)] = out_NN[i,:]
+    end
+    for (i, param) in enumerate(m.func.global_args)
+        D[Symbol(param)] = ones(Float32, size(X,1)) .* globals[i]
+    end
+    for (i, param) in enumerate(m.func.fixed_args)
+        D[Symbol(param)] = ones(Float32, size(X,1)) .* m.func.fixed_vals[i]
+    end
+    return D
+end
 # Assumes that the last layer has sigmoid activation function
 function setbounds(m::HybridModel, bounds::Dict{Symbol, Tuple{T,T}}) where {T}
     n_args = length(m.func.varying_args)

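Note (not part of the diff): for orientation, a minimal usage sketch of the new entry point, assembled from the test files in this commit (the @hybrid test function, the Chain/Dense network, setup, and the forcings shapes are all taken from test/core.jl below). The `using` line for the package itself, which provides @hybrid, HybridModel, setup, Varying, Fixed, and Global, is assumed and omitted.

using Lux, Random, Symbolics          # plus `using` of this package itself (assumed)

@syms α::Real β::Real γ::Real δ::Real

# Structured part: α, β are predicted per sample by the NN, δ is a learned
# global, γ is fixed, and forcings enters as a keyword argument.
structured = @hybrid function testfunc(α::Varying, β::Varying, γ::Fixed=1.0, δ::Global; forcings)
    T = forcings[1, :]
    return (exp.(α) .- β) ./ (γ .* δ) .+ T
end

NN = f32(Chain(Dense(5 => 4, sigmoid_fast), Dense(4 => 2, sigmoid_fast)))
model = HybridModel(NN, structured)
ps, st = setup(MersenneTwister(), model)
model_params = (nn = ps, globals = [1.2f0])

# Default path dispatches to runHybridModelSimple; forcings is forwarded
# as a keyword to the generated opt_func (one column per sample).
y = model(rand(Float32, 5, 5), model_params, st; forcings = ones(Float32, 1, 5))

# return_parameters = Val(true) dispatches to runHybridModelAll instead, which
# returns a Dict with :out plus one entry per Varying, Global, and Fixed parameter.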
src/macroHybrid.jl

Lines changed: 2 additions & 1 deletion
@@ -48,6 +48,7 @@ function optimize_func(fun, global_args, fixed_args, varying_args, fixed_vals)
     func_expr = copy(fun)
     func_body = func_expr.args[2]
     func_args = Expr(:tuple)
+    push!(func_args.args, :($(Expr(:parameters, :forcings))))
     varying_tuple = Expr(:(::), :varying_params, Expr(:curly, :Tuple, [:(Vector{Float32}) for _ in varying_args]...))
     push!(func_args.args, varying_tuple)
     global_array = Expr(:(::), :global_params, :(Vector{Float32}))
@@ -83,4 +84,4 @@ macro hybrid(fun)
     return quote
         PartitionedFunction($(esc(fun)), $(esc(args_syms)), $(esc(global_args)), $(esc(fixed_args)), $(esc(varying_args)), $(fixed_vals), $(esc(opt_func)))
     end
-end
+end

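To make the one-line change above concrete: pushing Expr(:parameters, :forcings) into the argument tuple gives every generated opt_func a forcings keyword argument. A hand-written equivalent of what the macro would now generate for the testfunc used in the tests could look roughly like this; it is a sketch under that assumption, not the macro's literal output.

# Rough hand-written equivalent of the generated opt_func for
#   testfunc(α::Varying, β::Varying, γ::Fixed=1.0, δ::Global; forcings)
# Varying parameters arrive as a tuple of vectors, Global ones as a vector,
# Fixed values are baked in, and forcings is now a keyword.
function opt_func_sketch(varying_params::Tuple{Vector{Float32}, Vector{Float32}},
                         global_params::Vector{Float32}; forcings)
    α, β = varying_params
    δ = global_params[1]
    γ = 1.0f0                                  # fixed value captured by the macro
    T = forcings[1, :]
    return (exp.(α) .- β) ./ (γ .* δ) .+ T
end

# Call shape matching the tests, e.g.:
opt_func_sketch(([1.0f0], [1.0f0]), [1.0f0]; forcings = ones(Float32, 1, 1))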
test/core.jl

Lines changed: 35 additions & 1 deletion
@@ -19,7 +19,7 @@ function test_structuredfunc()
     @test length(structured.fixed_vals) == 1
     @test structured.fixed_vals[1] == 1.0f0
     opt_func = structured.opt_func
-    @test opt_func(([1.0f0], [1.0f0]), [1.0f0])[1] == exp(1.0f0) - 1.0f0
+    @test opt_func(([1.0f0], [1.0f0]), [1.0f0]; forcings = nothing)[1] == exp(1.0f0) - 1.0f0
 end
 
 function test_hybridmodel()
@@ -45,6 +45,15 @@ function test_hybridmodel()
     model_params = (nn = ps, globals = globals)
     @test model(rand(Float32, 5), model_params, st) isa Float32
     @test model(rand(Float32, 5,5), model_params, st) isa Vector{Float32}
+    @testset "Total Output" begin
+        @test model(rand(Float32, 5), model_params, st; return_parameters = Val(true)) isa Dict
+        output = model(rand(Float32, 5), model_params, st; return_parameters = Val(true))
+        @testset "Arguments in output" for arg in [:α, :β, :γ, :δ]
+            @test arg in keys(output)
+        end
+        @test :out in keys(output)
+        @test output[:γ] == 1.0f0
+    end
 end
 
 function test_bounds()
@@ -112,3 +121,28 @@ function test_gradcalc()
     @test new_params != model_params
 end
 
+function test_forcings()
+    local input_size = 5
+    local α, β, γ, δ
+    @syms α::Real β::Real γ::Real δ::Real
+    structured = @hybrid function testfunc(α::Varying, β::Varying, γ::Fixed=1.0, δ::Global; forcings)
+        T = forcings[1,:]
+        return (exp.(α) .- β)./(γ .* δ) .+ T
+    end
+    NN = Chain(
+        Dense(input_size => 4, sigmoid_fast),
+        Dense(4 => 2, sigmoid_fast)
+    )
+    NN = f32(NN)
+    rng = MersenneTwister()
+    model = HybridModel(
+        NN,
+        structured
+    )
+    ps, st = setup(rng, model)
+    globals = [1.2f0]
+    model_params = (nn = ps, globals = globals)
+    @test model(rand(Float32, 5), model_params, st; forcings = ones(Float32, 1,1)) isa Float32
+    @test model(rand(Float32, 5,5), model_params, st; forcings = ones(Float32, 1,5)) isa Vector{Float32}
+    @test all(model(rand(Float32, 5,5), model_params, st; forcings = ones(Float32, 1,5)) .> 1.0f0)
+end

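A note on the shapes used in the new test_forcings: the forcings argument appears to be laid out as n_forcings × n_samples, matching the sample (column) dimension of X, and the structured function selects a forcing variable by row. A small sketch of that convention follows; it is an inference from the test inputs, not an API guarantee.

X        = rand(Float32, 5, 5)      # 5 features × 5 samples, as in the test
forcings = ones(Float32, 1, 5)      # 1 forcing variable × 5 samples
T = forcings[1, :]                  # row 1 -> one forcing value per sample (length 5)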
test/sample.jl

Lines changed: 3 additions & 2 deletions
@@ -4,8 +4,9 @@ const input_size = 5
 local α, β, γ, δ
 @syms α::Real β::Real γ::Real δ::Real
 
-structured = @hybrid function testfunc(α::Varying, β::Varying, γ::Fixed=1.0, δ::Global)
-    return (exp.(α) .- β)./(γ .* δ)
+structured = @hybrid function testfunc(α::Varying, β::Varying, γ::Fixed=1.0, δ::Global; forcings)
+    T = forcings[1,:]
+    return (exp.(α) .- β)./(γ .* δ) .+ T
 end
 
 # TODO, define first symbolic function and then apply macro?

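A quick numeric sanity sketch of the updated sample function (illustrative values only; δ = 1.2f0 mirrors the globals used in the tests): the forcing enters purely as an additive term on top of the previous expression.

α, β = [1.0f0], [1.0f0]             # stand-ins for the NN outputs
γ, δ = 1.0f0, 1.2f0                 # Fixed default and an illustrative Global value
T = ones(Float32, 1, 1)[1, :]       # one forcing value for one sample
y = (exp.(α) .- β) ./ (γ .* δ) .+ T
y[1] ≈ (exp(1.0f0) - 1.0f0) / 1.2f0 + 1.0f0   # true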