Skip to content

Commit f9680a6

Browse files
author
Frankie Robertson
committed
Add LaplaceAbilityTracker
1 parent ac8ea3d commit f9680a6

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

experiments/estimator_compare.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ function main()
2929
lh_grid_tracker = GriddedAbilityTracker(lh_ability_est, grid)
3030
prior_grid_tracker = GriddedAbilityTracker(prior_ability_est, grid)
3131
closed_normal_tracker = ClosedFormNormalAbilityTracker(prior_ability_est)
32+
laplace_normal_tracker = LaplaceAbilityTracker(prior_ability_est)
3233
rules = CatRules(
3334
MultiAbilityTracker([
3435
lh_grid_tracker,

src/aggregators/ability_tracker.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ end
6666
include("./ability_trackers/grid.jl")
6767
include("./ability_trackers/point.jl")
6868
include("./ability_trackers/closed_form_normal.jl")
69+
include("./ability_trackers/laplace.jl")
6970
include("./ability_trackers/multi.jl")
7071

7172
"""
Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
1-
struct LaplaceAbilityTracker <: AbilityTracker
2-
cur_ability::VarNormal
1+
struct LaplaceAbilityTracker{AbilityEstimatorT <: DistributionAbilityEstimator} <: AbilityTracker
2+
ability_estimator::AbilityEstimatorT
3+
optimizer::OneDimOptimOptimizer
4+
cur_ability::Union{Normal, Nothing}
35
end
46

7+
function LaplaceAbilityTracker(ability_estimator, optimizer)
8+
@warn "LaplaceAbilityTracker is a work in progress, and will not accelerate anything yet."
9+
LaplaceAbilityTracker(ability_estimator, optimizer, nothing)
10+
end
511

6-
{AbilityEstimatorT <: PointAbilityEstimator}
12+
function track!(responses, ability_tracker::LaplaceAbilityTracker)
13+
f(x) = pdf(ability_tracker.ability_estimator, responses, x)
14+
mode = ability_tracker.optimizer(f)
15+
stddev = -(ForwardDiff.hessian(f, mode) ^ (-1))
16+
ability_tracker.cur_ability = Normal(mode, stddev)
17+
end

0 commit comments

Comments
 (0)