Skip to content

Commit b6ec807

Browse files
authored
Merge pull request #2 from EarthyScience/bounds
bounds
2 parents 7e031ea + 7601385 commit b6ec807

File tree

5 files changed

+65
-4
lines changed

5 files changed

+65
-4
lines changed

src/HybridSymbolics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using ADTypes
99
using ForwardDiff
1010
using ProgressMeter
1111
using ProtoStructs
12+
using Random
1213

1314
include("hybridModel.jl")
1415
include("macroHybrid.jl")

src/hybridModel.jl

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export Global, Varying, Fixed, PartitionedFunction, HybridModel
1+
export Global, Varying, Fixed, PartitionedFunction, HybridModel, setbounds, setup
22
export HybridSymbolic, SymbolTypes
33

44
abstract type SymbolTypes end
@@ -39,9 +39,15 @@ struct PartitionedFunction{F,O,A1,A2,A3,A4,V} <: HybridSymbolic
3939
end
4040
end
4141

42-
struct HybridModel <: HybridSymbolic
42+
@proto struct HybridModel{T} <: HybridSymbolic
4343
nn::Lux.Chain
4444
func::PartitionedFunction
45+
p_min::T
46+
p_max::T
47+
end
48+
49+
function HybridModel(nn::Lux.Chain, func::PartitionedFunction)
50+
return HybridModel(nn, func, nothing, nothing)
4551
end
4652
# TODO: This needs to be more general. i.e. ŷ = NN(α * NN(x) + β).
4753

@@ -60,4 +66,24 @@ function (m::HybridModel)(X::Vector{Float32}, params, st)
6066
out_NN = m.nn(X, ps, st)[1]
6167
out = m.func.opt_func(tuple([[out_NN[1]] for i = 1:n_varargs]...), globals)
6268
return out[1]
63-
end
69+
end
70+
71+
# Assumes that the last layer has sigmoid activation function
72+
function setbounds(m::HybridModel, bounds::Dict{Symbol, Tuple{T,T}}) where {T}
73+
n_args = length(m.func.varying_args)
74+
p_min = zeros(Float32, n_args)
75+
p_max = zeros(Float32, n_args)
76+
for (i,arg) in enumerate(Symbol.(m.func.varying_args))
77+
@assert arg in keys(bounds)
78+
p_min[i] = bounds[arg][1]
79+
p_max[i] = bounds[arg][2]
80+
end
81+
p_range = p_max .- p_min
82+
wf = WrappedFunction((x) -> x .* (p_range) .+ p_min)
83+
new_nn = Chain(m.nn, wf)
84+
return HybridModel(new_nn, m.func, p_min, p_max)
85+
end
86+
87+
function setup(rng::AbstractRNG, m::HybridModel)
88+
return Lux.setup(rng, m.nn)
89+
end

test/core.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,38 @@ function test_hybridmodel()
4747
@test model(rand(Float32, 5,5), model_params, st) isa Vector{Float32}
4848
end
4949

50+
function test_bounds()
51+
local input_size = 5
52+
local α, β, γ, δ
53+
@syms α::Real β::Real γ::Real δ::Real
54+
structured = @hybrid function testfunc::Varying, β::Varying, γ::Fixed=1.0, δ::Global)
55+
return (exp.(α) .- β)./.* δ)
56+
end
57+
NN = Chain(
58+
Dense(input_size => 4, sigmoid_fast),
59+
Dense(4 => 2, sigmoid_fast)
60+
)
61+
NN = f32(NN)
62+
rng = MersenneTwister()
63+
model = HybridModel(
64+
NN,
65+
structured
66+
)
67+
model = setbounds(model, Dict( => (-1.0f0, 1.0f0), => (-1.0f0, 1.0f0)))
68+
@test model isa HybridModel
69+
ps, st = setup(rng, model)
70+
globals = [1.2f0]
71+
model_params = (nn = ps, globals = globals)
72+
@test model(rand(Float32, 5), model_params, st) isa Float32
73+
@test model(rand(Float32, 5,5), model_params, st) isa Vector{Float32}
74+
@test all(model.p_min .== -1.0f0)
75+
@test all(model.p_max .== 1.0f0)
76+
@testset "Testing Bounds: input $item" for item in [rand(Float32, 5) for _ in 1:10]
77+
output_params = model.nn(item, ps, st)[1]
78+
@test all(output_params .>= -1.0f0) && all(output_params .<= 1.0f0)
79+
end
80+
end
81+
5082
function test_gradcalc()
5183
local input_size = 5
5284
local α, β, γ, δ

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ include("./core.jl")
55
@testset "HybridSymbolics.jl" begin
66
@testset test_structuredfunc()
77
@testset test_hybridmodel()
8+
@testset test_bounds()
89
@testset test_gradcalc()
910
end

test/sample.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ NN = Chain(
2020

2121
NN = f32(NN)
2222
rng = MersenneTwister()
23-
ps, st = Lux.setup(rng, NN)
2423

2524
model = HybridModel(
2625
NN,
2726
structured
2827
)
28+
model = setbounds(model, Dict( => (-1.0f0, 1.0f0), => (-1.0f0, 1.0f0)))
2929

30+
ps, st = setup(rng, model)
3031
globals = [1.2f0]
3132
model_params = (nn = ps, globals = globals)
3233

0 commit comments

Comments
 (0)