Skip to content

Commit 1a6b7d2

Browse files
author
Frankie Robertson
committed
Various changes
* Add pointwise/category information criterion * Make a bunch of stuff generic for number * Other stuff too
1 parent 19c0db0 commit 1a6b7d2

File tree

20 files changed

+385
-190
lines changed

20 files changed

+385
-190
lines changed

src/aggregators/Aggregators.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using StaticArrays: SVector
1010
using Distributions: Distribution, Normal, Distributions
1111
using Base.Threads
1212
using ForwardDiff: ForwardDiff
13+
using LogarithmicNumbers: Logarithmic, ULogarithmic
1314

1415
using FittedItemBanks: AbstractItemBank, ContinuousDomain,
1516
DichotomousSmoothedItemBank, DiscreteIndexableDomain,
@@ -24,12 +25,14 @@ using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome,
2425
find1_type_sloppy
2526
using PsychometricsBazaarBase.Integrators: Integrators,
2627
BareIntegrationResult,
27-
FixedGridIntegrator, IntReturnType,
28+
FixedGridIntegrator,
29+
IntReturnType,
2830
IntValue, Integrator,
2931
PreallocatedFixedGridIntegrator,
3032
normdenom
3133
using PsychometricsBazaarBase.Optimizers: OneDimOptimOptimizer, Optimizer
3234
using PsychometricsBazaarBase.ConstDistributions: std_normal, std_mv_normal
35+
import Distributions: pdf
3336

3437
import FittedItemBanks
3538
import PsychometricsBazaarBase.IntegralCoeffs
@@ -38,7 +41,8 @@ export AbilityEstimator, TrackedResponses
3841
export AbilityTracker, NullAbilityTracker, PointAbilityTracker, GriddedAbilityTracker
3942
export ClosedFormNormalAbilityTracker, track!
4043
export response_expectation, expectation, distribution_estimator
41-
export PointAbilityEstimator, PriorAbilityEstimator, LikelihoodAbilityEstimator
44+
export PointAbilityEstimator, PriorAbilityEstimator
45+
export SafeLikelihoodAbilityEstimator, LikelihoodAbilityEstimator
4246
export ModeAbilityEstimator, MeanAbilityEstimator
4347
export Speculator, replace_speculation!, normdenom, maybe_tracked_ability_estimate
4448
export AbilityIntegrator, AbilityOptimizer
@@ -70,6 +74,10 @@ end
7074
abstract type DistributionAbilityEstimator <: AbilityEstimator end
7175
function DistributionAbilityEstimator(bits...)
7276
@returnsome find1_instance(DistributionAbilityEstimator, bits)
77+
point_ability_estimator = find1_instance(PointAbilityEstimator, bits)
78+
if point_ability_estimator !== nothing
79+
return distribution_estimator(point_ability_estimator)
80+
end
7381
end
7482

7583
abstract type PointAbilityEstimator <: AbilityEstimator end

src/aggregators/optimizers.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,6 @@ function (optim::AbilityOptimizer)(f::F,
4747
est,
4848
tracked_responses::TrackedResponses;
4949
kwargs...) where {F}
50-
optim(maybe_apply_prior(f, est), AbilityLikelihood(tracked_responses); kwargs...)
50+
#optim(maybe_apply_prior(f, est), AbilityLikelihood(tracked_responses); kwargs...)
51+
optim(f, pdf(est, tracked_responses); kwargs...)
5152
end

src/aggregators/riemann.jl

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,32 @@ function (integrator::RiemannEnumerationIntegrator)(f::F,
2626
return BareIntegrationResult(result)
2727
end
2828

29-
function (integrator::Union{RiemannEnumerationIntegrator, FunctionIntegrator})(f::F,
30-
ncomp,
31-
est,
32-
tracked_responses::TrackedResponses;
33-
kwargs...) where {F}
34-
integrator(maybe_apply_prior(f, est),
29+
function (integrator::RiemannEnumerationIntegrator)(
30+
f::F,
31+
ncomp,
32+
est,
33+
tracked_responses::TrackedResponses;
34+
kwargs...
35+
) where {F}
36+
integrator(
37+
maybe_apply_prior(f, est),
3538
ncomp,
3639
AbilityLikelihood(tracked_responses);
37-
kwargs...)
40+
kwargs...
41+
)
42+
end
43+
44+
function (integrator::FunctionIntegrator)(
45+
f::F,
46+
ncomp,
47+
est,
48+
tracked_responses::TrackedResponses;
49+
kwargs...
50+
) where {F}
51+
integrator(
52+
f,
53+
ncomp,
54+
pdf(est, tracked_responses);
55+
kwargs...
56+
)
3857
end

src/logitembank.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@ inner_ir(ir::ItemResponse{<:LogItemBank}) = ItemResponse(ir.item_bank.inner, ir.
2121
## TODO: Support item banks with other response types e.g. Float32
2222

2323
function FittedItemBanks.resp(ir::ItemResponse{<:LogItemBank}, θ)
24-
exp(ULogarithmic{Float64}, FittedItemBanks.log_resp(inner_ir(ir), θ))
24+
exp(ULogarithmic, FittedItemBanks.log_resp(inner_ir(ir), θ))
2525
end
2626

2727
function FittedItemBanks.resp(ir::ItemResponse{<:LogItemBank}, response, θ)
2828
exp(
29-
ULogarithmic{Float64},
29+
ULogarithmic,
3030
FittedItemBanks.log_resp(inner_ir(ir), response, θ)
3131
)
3232
end
3333

3434
function FittedItemBanks.resp_vec(ir::ItemResponse{<:LogItemBank}, θ)
35-
exp.(ULogarithmic{Float64}, FittedItemBanks.log_resp_vec(inner_ir(ir), θ))
35+
exp.(ULogarithmic, FittedItemBanks.log_resp_vec(inner_ir(ir), θ))
3636
end
3737

3838
@forward LogItemBank.inner Base.length,

src/next_item_rules/NextItemRules.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,22 @@ Springer, New York, NY.
1111
module NextItemRules
1212

1313
using DocStringExtensions: FUNCTIONNAME, TYPEDEF, TYPEDFIELDS
14-
using PsychometricsBazaarBase.Parameters: @with_kw
14+
using PsychometricsBazaarBase.Parameters
1515
using LinearAlgebra: det, tr
1616
using Random: AbstractRNG, Xoshiro
1717

1818
using ..Responses: BareResponses
1919
using ..ConfigBase
2020
using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome,
2121
find1_instance, find1_type
22-
using PsychometricsBazaarBase.Integrators: Integrator
22+
using PsychometricsBazaarBase.Integrators: Integrator, intval
2323
using PsychometricsBazaarBase: Integrators
2424
import PsychometricsBazaarBase.IntegralCoeffs
2525
using FittedItemBanks: AbstractItemBank, DiscreteDomain, DomainType,
2626
ItemResponse, OneDimContinuousDomain, domdims, item_params,
2727
resp, resp_vec, responses
2828
using ..Aggregators
29-
using ..Aggregators: covariance_matrix
29+
using ..Aggregators: covariance_matrix, FunctionProduct
3030

3131
using Distributions: logccdf, logcdf, pdf
3232
using Base.Threads
@@ -38,13 +38,17 @@ import ForwardDiff
3838
export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread
3939
export NextItemRule, ItemStrategyNextItemRule
4040
export UrryItemCriterion, InformationItemCriterion
41+
export LikelihoodWeightedItemCriterion, PointItemCriterion
42+
export LikelihoodWeightedItemCategoryCriterion, PointItemCategoryCriterion
43+
export ObservedInformationPointwiseItemCategoryCriterion
44+
export RawEmpiricalInformationPointwiseItemCategoryCriterion
45+
export EmpiricalInformationPointwiseItemCategoryCriterion
46+
export TotalItemInformation
4147
export RandomNextItemRule
4248
export PiecewiseNextItemRule, MemoryNextItemRule, FixedFirstItemNextItemRule
4349
export ExhaustiveSearch
44-
export catr_next_item_aliases
4550
export preallocate
46-
export compute_criteria, compute_criterion, compute_multi_criterion,
47-
compute_pointwise_criterion
51+
export compute_criteria, compute_criterion, compute_multi_criterion
4852
export best_item
4953
export PointResponseExpectation, DistributionResponseExpectation
5054
export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer
@@ -70,15 +74,15 @@ include("./combinators/scalarizers.jl")
7074
include("./combinators/likelihood.jl")
7175

7276
# Criteria
73-
include("./criteria/item/information_special.jl")
74-
include("./criteria/item/information_support.jl")
7577
include("./criteria/item/information.jl")
7678
include("./criteria/item/urry.jl")
7779
include("./criteria/state/ability_variance.jl")
80+
include("./criteria/pointwise/information_special.jl")
81+
include("./criteria/pointwise/information_support.jl")
82+
include("./criteria/pointwise/information.jl")
7883
include("./criteria/pointwise/kl.jl")
7984

8085
# Porcelain
8186
include("./porcelain/porcelain.jl")
82-
include("./porcelain/aliases.jl")
8387

8488
end

src/next_item_rules/combinators/expectation.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,16 @@ item 1-ply ahead.
6767
"""
6868
struct ExpectationBasedItemCriterion{
6969
ResponseExpectationT <: ResponseExpectation,
70-
CriterionT <: Union{StateCriterion, ItemCriterion}
70+
CriterionT <: Union{StateCriterion, ItemCriterion, ItemCategoryCriterion},
7171
} <: ItemCriterion
7272
response_expectation::ResponseExpectationT
7373
criterion::CriterionT
7474
end
7575

7676
function _get_some_criterion(bits...; kwargs...)
7777
@returnsome StateCriterion(bits...; kwargs...)
78-
@returnsome ItemCriterion(bits...; kwargs...)
78+
@returnsome ItemCriterion(bits...; skip_expectation=true, kwargs...)
79+
@returnsome ItemCategoryCriterion(bits...)
7980
end
8081

8182
function ExpectationBasedItemCriterion(bits...;
@@ -95,13 +96,16 @@ function init_thread(::ExpectationBasedItemCriterion, responses::TrackedResponse
9596
Speculator(responses, 1)
9697
end
9798

98-
function _generic_criterion(criterion::StateCriterion, tracked_responses, item_idx)
99+
function _generic_criterion(criterion::StateCriterion, tracked_responses, _item_idx, _response)
99100
compute_criterion(criterion, tracked_responses)
100101
end
101102
# TODO: Support init_thread for wrapped ItemCriterion
102-
function _generic_criterion(criterion::ItemCriterion, tracked_responses, item_idx)
103+
function _generic_criterion(criterion::ItemCriterion, tracked_responses, item_idx, _response)
103104
compute_criterion(criterion, tracked_responses, item_idx)
104105
end
106+
function _generic_criterion(criterion::ItemCategoryCriterion, tracked_responses, item_idx, response)
107+
compute_criterion(criterion, tracked_responses, item_idx, response)
108+
end
105109

106110
function compute_criterion(
107111
item_criterion::ExpectationBasedItemCriterion,
@@ -116,7 +120,7 @@ function compute_criterion(
116120
for (prob, possible_response) in zip(exp_resp, possible_responses)
117121
replace_speculation!(speculator, SVector(item_idx), SVector(possible_response))
118122
res += prob *
119-
_generic_criterion(item_criterion.criterion, speculator.responses, item_idx)
123+
_generic_criterion(item_criterion.criterion, speculator.responses, item_idx, possible_response)
120124
end
121125
res
122126
end

src/next_item_rules/combinators/likelihood.jl

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,90 @@ struct LikelihoodWeightedItemCriterion{
88
estimator::AbilityEstimatorT
99
end
1010

11+
function LikelihoodWeightedItemCriterion(bits...)
12+
@requiresome dist_est_integrator_pair = get_dist_est_and_integrator(bits...)
13+
(dist_est, integrator) = dist_est_integrator_pair
14+
criterion = PointwiseItemCriterion(bits...)
15+
return LikelihoodWeightedItemCriterion(criterion, integrator, dist_est)
16+
end
17+
1118
function compute_criterion(
12-
lwic::LikelihoodWeightedItemCriterion,
13-
tracked_responses::TrackedResponses,
14-
item_idx
19+
lwic::LikelihoodWeightedItemCriterion,
20+
tracked_responses::TrackedResponses,
21+
item_idx
1522
)
1623
func = FunctionProduct(
17-
pdf(lwic.estimator, tracked_responses), lwic.criterion(tracked_responses, item_idx))
18-
lwic.integrator(func, 0, lwic.estimator, tracked_responses)
24+
pdf(lwic.estimator, tracked_responses), ability -> compute_criterion(lwic.criterion, tracked_responses, item_idx, ability))
25+
intval(lwic.integrator(func, 0, lwic.estimator, tracked_responses))
26+
end
27+
28+
struct PointItemCriterion{
29+
PointwiseItemCriterionT <: PointwiseItemCriterion,
30+
AbilityEstimatorT <: PointAbilityEstimator
31+
} <: ItemCriterion
32+
criterion::PointwiseItemCriterionT
33+
estimator::AbilityEstimatorT
1934
end
35+
36+
function compute_criterion(
37+
pic::PointItemCriterion,
38+
tracked_responses::TrackedResponses,
39+
item_idx
40+
)
41+
ability = maybe_tracked_ability_estimate(
42+
tracked_responses,
43+
pic.estimator
44+
)
45+
return compute_criterion(pic.criterion, tracked_responses, item_idx, ability)
46+
end
47+
48+
struct LikelihoodWeightedItemCategoryCriterion{
49+
PointwiseItemCategoryCriterionT <: PointwiseItemCategoryCriterion,
50+
AbilityIntegratorT <: AbilityIntegrator,
51+
AbilityEstimatorT <: DistributionAbilityEstimator
52+
} <: ItemCategoryCriterion
53+
criterion::PointwiseItemCategoryCriterionT
54+
integrator::AbilityIntegratorT
55+
estimator::AbilityEstimatorT
56+
end
57+
58+
function LikelihoodWeightedItemCategoryCriterion(bits...)
59+
@requiresome dist_est_integrator_pair = get_dist_est_and_integrator(bits...)
60+
(dist_est, integrator) = dist_est_integrator_pair
61+
criterion = PointwiseItemCategoryCriterion(bits...)
62+
return LikelihoodWeightedItemCategoryCriterion(criterion, integrator, dist_est)
63+
end
64+
65+
function compute_criterion(
66+
lwicc::LikelihoodWeightedItemCategoryCriterion,
67+
tracked_responses::TrackedResponses,
68+
item_idx,
69+
category
70+
)
71+
func = FunctionProduct(
72+
pdf(lwicc.estimator, tracked_responses),
73+
ability -> compute_criterion(lwicc.criterion, tracked_responses, item_idx, ability, category)
74+
)
75+
intval(lwicc.integrator(func, 0, lwicc.estimator, tracked_responses))
76+
end
77+
78+
struct PointItemCategoryCriterion{
79+
PointwiseItemCategoryCriterionT <: PointwiseItemCategoryCriterion,
80+
AbilityEstimatorT <: PointAbilityEstimator
81+
} <: ItemCategoryCriterion
82+
criterion::PointwiseItemCategoryCriterionT
83+
estimator::AbilityEstimatorT
84+
end
85+
86+
function compute_criterion(
87+
pic::PointItemCategoryCriterion,
88+
tracked_responses::TrackedResponses,
89+
item_idx,
90+
category
91+
)
92+
ability = maybe_tracked_ability_estimate(
93+
tracked_responses,
94+
pic.estimator
95+
)
96+
return compute_criterion(pic.criterion, tracked_responses, item_idx, ability, category)
97+
end

src/next_item_rules/criteria/item/information.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
# TODO: Should have Variants for point ability versus distribution ability
2-
struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <:
2+
@kw_only struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <:
33
ItemCriterion
44
ability_estimator::AbilityEstimatorT
55
expected_item_information::F
66
end
77

8-
function InformationItemCriterion(ability_estimator)
9-
InformationItemCriterion(ability_estimator, expected_item_information)
8+
function InformationItemCriterion(ability_estimator::PointAbilityEstimator)
9+
InformationItemCriterion(;
10+
ability_estimator,
11+
expected_item_information
12+
)
13+
end
14+
15+
function InformationItemCriterion(bits...)
16+
@requiresome ability_estimator = PointAbilityEstimator(bits...)
17+
InformationItemCriterion(ability_estimator)
1018
end
1119

1220
function compute_criterion(

src/next_item_rules/criteria/item/urry.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ struct UrryItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <: ItemCrit
99
ability_estimator::AbilityEstimatorT
1010
end
1111

12+
function UrryItemCriterion(bits...)
13+
@requiresome ability_estimator = PointAbilityEstimator(bits...)
14+
UrryItemCriterion(ability_estimator)
15+
end
16+
1217
# TODO: Slow + poor error handling
1318
function raw_difficulty(item_bank, item_idx)
1419
item_params(item_bank, item_idx).difficulty

0 commit comments

Comments
 (0)