Skip to content

Commit 7cd7c9e

Browse files
committed
replace transformations by Exp
because exlementwise(exp) failed on AD on GPU test on GPU
1 parent 56adfbf commit 7cd7c9e

File tree

6 files changed

+40
-27
lines changed

6 files changed

+40
-27
lines changed

src/DoubleMM/f_doubleMM.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ const θall = vcat(θP, θM)
66

77
const θP_nor0 = θP[(:K2,)]
88

9-
const transP = elementwise(exp)
10-
const transM = elementwise(exp)
9+
# const transP = elementwise(exp)
10+
# const transM = elementwise(exp)
1111

12-
const transMS = Stacked(elementwise(identity), elementwise(exp))
12+
# const transMS = Stacked(elementwise(identity), elementwise(exp))
1313

1414
const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM)))
1515

@@ -114,15 +114,17 @@ end
114114

115115

116116
function HVI.get_hybridproblem_transforms(prob::DoubleMMCase; scenario::NTuple = ())
117+
_θP, _θM = get_hybridproblem_par_templates(prob; scenario)
117118
if (:stackedMS scenario)
118-
return ((; transP, transM = transMS))
119+
return (; transP = Stacked((HVI.Exp(),),(1:length(_θP),)),
120+
transM = Stacked((identity,HVI.Exp(),),(1:1, 2:length(_θM),)))
119121
elseif (:transIdent scenario)
120122
# identity transformations, should AD on GPU
121-
_θP, _θM = get_hybridproblem_par_templates(prob; scenario)
122123
return (; transP = Stacked((identity,),(1:length(_θP),)),
123124
transM = Stacked((identity,),(1:length(_θM),)))
124125
end
125-
(; transP, transM)
126+
(; transP = Stacked((HVI.Exp(),),(1:length(_θP),)),
127+
transM = Stacked((HVI.Exp(),),(1:length(_θM),)))
126128
end
127129

128130
# function HVI.get_hybridproblem_sizes(::DoubleMMCase; scenario = ())

src/HybridVariationalInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using StaticArrays: StaticArrays as SA
2121
using Functors
2222

2323
#export Exp
24-
include("bijectors_exp.jl")
24+
include("bijectors_utils.jl")
2525

2626
export ComponentArrayInterpreter, flatten1, get_concrete, get_positions
2727
include("ComponentArrayInterpreter.jl")

src/bijectors_exp.jl

Lines changed: 0 additions & 17 deletions
This file was deleted.

src/bijectors_utils.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
struct Exp <: Bijector
2+
end
3+
4+
#Functors.@functor Exp
5+
Bijectors.transform(b::Exp, x) = exp.(x) # note the broadcast
6+
Bijectors.transform(ib::Inverse{<:Exp}, y) = log.(y)
7+
8+
# `logabsdetjac`
9+
Bijectors.logabsdetjac(b::Exp, x) = sum(x)
10+
11+
`with_logabsdet_jacobian`
12+
function Bijectors.with_logabsdet_jacobian(b::Exp, x)
13+
return exp.(x), sum(x)
14+
end
15+
# function Bijectors.with_logabsdet_jacobian(ib::Inverse{<:Exp}, y)
16+
# x = transform(ib, y)
17+
# return x, -logabsdetjac(inverse(ib), x)
18+
# end
19+
20+
21+
Bijectors.is_monotonically_increasing(::Exp) = true

test/test_HybridProblem.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ test_with_flux = (scenario) -> begin
199199
scenf = (scenario..., :use_Flux, :use_gpu, :omit_r0)
200200
rng = StableRNG(111)
201201
# here using DoubleMMCase() directly rather than construct_problem
202+
#(;transP, transM) = get_hybridproblem_transforms(DoubleMM.DoubleMMCase(); scenario = scenf)
202203
prob = probg = HybridProblem(DoubleMM.DoubleMMCase(); scenario = scenf);
203204
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
204205
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario = scenf)
@@ -216,13 +217,15 @@ test_with_flux = (scenario) -> begin
216217
maxiters = 37,
217218
);
218219
@test cdev(ϕ.unc.ρsM)[1] > 0
220+
@test probo.ϕunc == cdev(ϕ.unc)
219221
test_correlation = () -> begin
220-
n_epoch = 100 # requires
222+
n_epoch = 20 # requires
221223
(; ϕ, θP, resopt, probo) = solve(prob, solver; scenario = scenf,
222224
maxiters = n_batches_in_epoch * n_epoch,
223225
callback = callback_loss(n_batches_in_epoch*5)
224226
);
225227
@test cdev(ϕ.unc.ρsM)[1] > 0
228+
@test probo.ϕunc == cdev(ϕ.unc)
226229
# predict using problem and its associated dataloader
227230
(; θ, y, entropy_ζ) = predict_gf(rng, probo; scenario = scenf, n_sample_pred = 200);
228231
mean_θ = CA.ComponentVector(mean(CA.getdata(θ); dims = 2)[:, 1], CA.getaxes(θ[:, 1])[1])
@@ -240,8 +243,7 @@ test_with_flux = (scenario) -> begin
240243
scenf = (scenario..., :use_Flux, :use_gpu, :omit_r0, :f_on_gpu)
241244
rng = StableRNG(111)
242245
probg = HybridProblem(DoubleMM.DoubleMMCase(); scenario = scenf);
243-
# TODO replace exp by Exp Transformer
244-
prob = CP.update(probg, transM = identity, transP = identity);
246+
#prob = CP.update(probg, transM = identity, transP = identity);
245247
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
246248
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario = scenf)
247249
n_batches_in_epoch = n_site ÷ n_batch

test/test_bijectors_utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ b2 = elementwise(exp)
2222
b2s = Stacked((b2,b2),(1:3,4:4))
2323
b3 = HybridVariationalInference.Exp()
2424
b3s = Stacked((b3,b3), (1:3,4:4))
25+
#b3s = Stacked((b3,),(1:4,))
2526

2627

2728
y = trans(x, b2)
@@ -35,6 +36,10 @@ dy = Zygote.gradient(x -> trans(x,b2), x)
3536
end;
3637

3738
@testset "Exp" begin
39+
y1 = b3(x)
40+
y2 = b3s(x)
41+
@test all(inverse(b3)(y2) .≈ x)
42+
@test all(inverse(b3s)(y2) .≈ x)
3843
ye = trans(x, b3)
3944
dye = Zygote.gradient(x -> trans(x,b3), x)
4045
@test ye == y

0 commit comments

Comments
 (0)