Skip to content

Commit 385052f

Browse files
authored
Mirt redux (#72)
* Add postposterior covariance matrix based mirt rules * Qualify usage of even_grid in dt test * Formatting of ability_estimator * Add todo note to comparison.jl * Refactor dispersion around ScalarizedStateCriteron * Apply formatting
1 parent ec41675 commit 385052f

File tree

10 files changed

+348
-108
lines changed

10 files changed

+348
-108
lines changed

src/Comparison.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ function run_comparison(comparison::CatComparisonConfig{ReplayResponsesExecution
263263
items_answered = items_answered
264264
)
265265
if :after_item_criteria in comparison.phases
266+
# TOOD: Combine with next_item if possible and requested?
266267
timed_item_criteria = @timed Stateful.item_criteria(cat)
267268
measure_all(
268269
comparison,

src/aggregators/Aggregators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using PsychometricsBazaarBase.ConfigTools
1919
using PsychometricsBazaarBase.Integrators
2020
using PsychometricsBazaarBase: Integrators
2121
using PsychometricsBazaarBase.Optimizers
22-
using PsychometricsBazaarBase.ConstDistributions: std_normal
22+
using PsychometricsBazaarBase.ConstDistributions: std_normal, std_mv_normal
2323

2424
import FittedItemBanks
2525
import PsychometricsBazaarBase.IntegralCoeffs

src/aggregators/ability_estimator.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ struct PriorAbilityEstimator{PriorT <: Distribution} <: DistributionAbilityEstim
2828
prior::PriorT
2929
end
3030

31-
PriorAbilityEstimator() = PriorAbilityEstimator(std_normal)
31+
function PriorAbilityEstimator(; ncomp = 0)
32+
if ncomp == 0
33+
return PriorAbilityEstimator(std_normal)
34+
else
35+
return PriorAbilityEstimator(std_mv_normal(ncomp))
36+
end
37+
end
3238

3339
function pdf(est::PriorAbilityEstimator,
3440
tracked_responses::TrackedResponses)
@@ -73,6 +79,21 @@ function mean_1d(integrator::AbilityIntegrator,
7379
denom)
7480
end
7581

82+
function mean(
83+
integrator::AbilityIntegrator,
84+
est::DistributionAbilityEstimator,
85+
tracked_responses::TrackedResponses,
86+
denom = normdenom(integrator, est, tracked_responses)
87+
)
88+
n = domdims(tracked_responses.item_bank)
89+
expectation(IntegralCoeffs.id,
90+
n,
91+
integrator,
92+
est,
93+
tracked_responses,
94+
denom)
95+
end
96+
7697
function variance_given_mean(integrator::AbilityIntegrator,
7798
est::DistributionAbilityEstimator,
7899
tracked_responses::TrackedResponses,
@@ -97,6 +118,36 @@ function variance(integrator::AbilityIntegrator,
97118
denom)
98119
end
99120

121+
function covariance_matrix_given_mean(
122+
integrator::AbilityIntegrator,
123+
est::DistributionAbilityEstimator,
124+
tracked_responses::TrackedResponses,
125+
mean,
126+
denom = normdenom(integrator, est, tracked_responses)
127+
)
128+
n = domdims(tracked_responses.item_bank)
129+
expectation(IntegralCoeffs.OuterProdDev(mean),
130+
n,
131+
integrator,
132+
est,
133+
tracked_responses,
134+
denom)
135+
end
136+
137+
function covariance_matrix(
138+
integrator::AbilityIntegrator,
139+
est::DistributionAbilityEstimator,
140+
tracked_responses::TrackedResponses,
141+
denom = normdenom(integrator, est, tracked_responses))
142+
covariance_matrix_given_mean(
143+
integrator,
144+
est,
145+
tracked_responses,
146+
mean(integrator, est, tracked_responses, denom),
147+
denom
148+
)
149+
end
150+
100151
struct ModeAbilityEstimator{
101152
DistEst <: DistributionAbilityEstimator,
102153
OptimizerT <: AbilityOptimizer

src/next_item_rules/NextItemRules.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,25 @@ import PsychometricsBazaarBase.IntegralCoeffs
2626
using FittedItemBanks
2727
using FittedItemBanks: item_params
2828
using ..Aggregators
29+
using ..Aggregators: covariance_matrix
2930

3031
using Distributions, Base.Threads, Base.Order, StaticArrays
3132
using ConstructionBase: constructorof
3233
import ForwardDiff
3334

3435
export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread
3536
export NextItemRule, ItemStrategyNextItemRule
36-
export UrryItemCriterion, InformationItemCriterion, DRuleItemCriterion, TRuleItemCriterion
37+
export UrryItemCriterion, InformationItemCriterion
3738
export RandomNextItemRule
3839
export ExhaustiveSearch1Ply
3940
export catr_next_item_aliases
4041
export preallocate
4142
export compute_criteria
4243
export PointResponseExpectation, DistributionResponseExpectation
44+
export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer
45+
export AbilityCovarianceStateCriteria, StateCriteria, ItemCriteria
46+
export InformationMatrixCriteria
47+
export ScalarizedStateCriteron, ScalarizedItemCriteron
4348

4449
"""
4550
$(TYPEDEF)
@@ -68,6 +73,7 @@ end
6873

6974
include("./random.jl")
7075
include("./information.jl")
76+
include("./information_special.jl")
7177
include("./objective_function.jl")
7278
include("./expectation.jl")
7379

@@ -197,6 +203,7 @@ function compute_criteria(
197203
compute_criteria(rule.criterion, responses, items)
198204
end
199205

206+
include("./mirt.jl")
200207
include("./aliases.jl")
201208
include("./preallocate.jl")
202209

src/next_item_rules/aliases.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ const mirtcat_next_item_aliases = Dict(
4040
# 'MEPV' for minimum expected posterior variance
4141
"MEPV" => _mirtcat_helper((bits, ability_estimator) -> ExpectationBasedItemCriterion(
4242
ability_estimator,
43-
AbilityVarianceStateCriterion(bits...)))
43+
AbilityVarianceStateCriterion(bits...))),
44+
"Drule" => _mirtcat_helper((bits, ability_estimator) -> ScalarizedItemCriteron(
45+
InformationMatrixCriteria(ability_estimator),
46+
DeterminantScalarizer())),
47+
"Trule" => _mirtcat_helper((bits, ability_estimator) -> ScalarizedItemCriteron(
48+
InformationMatrixCriteria(ability_estimator),
49+
TraceScalarizer()))
4450
)
4551

4652
# 'MLWI' for maximum likelihood weighted information

src/next_item_rules/information.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@ using FittedItemBanks: CdfMirtItemBank,
33
using FittedItemBanks: inner_item_response, norm_abil, y_offset, irf_size
44
using StatsFuns: logaddexp
55

6+
function log_resp_vec(ir::ItemResponse{<:TransferItemBank}, θ)
7+
= norm_abil(ir, θ)
8+
return SVector(
9+
logccdf(ir.item_bank.distribution, nθ),
10+
logcdf(ir.item_bank.distribution, nθ)
11+
)
12+
end
13+
14+
function log_resp(ir::ItemResponse{<:TransferItemBank}, resp, θ)
15+
logcdf(ir.item_bank.distribution, norm_abil(ir, θ))
16+
end
17+
618
function log_resp_vec(ir::ItemResponse{<:CdfMirtItemBank}, θ)
719
= norm_abil(ir, θ)
820
SVector(logccdf(ir.item_bank.distribution, nθ),
@@ -52,26 +64,21 @@ function log_resp(ir::ItemResponse{<:AnySlipOrGuessItemBank}, val, θ)
5264
log_transform_irf_y(ir, val, log_resp(inner_item_response(ir), val, θ))
5365
end
5466

55-
# How does this compare with expected_item_information. Speeds/accuracies?
56-
# TODO: Which response models is this valid for?
57-
# TODO: Citation/source for this equation
58-
# TODO: Do it in log space?
59-
function item_information(ir::ItemResponse, θ)
60-
# irθ_prime = ForwardDiff.derivative(ir, θ)
61-
irθ_prime = ForwardDiff.derivative(x -> resp(ir, x), θ)
62-
irθ = resp(ir, θ)
63-
if irθ_prime == 0.0
64-
return 0.0
65-
else
66-
return (irθ_prime * irθ_prime) / (irθ * (1 - irθ))
67-
end
68-
end
69-
7067
function vector_hessian(f, x, n)
7168
out = ForwardDiff.jacobian(x -> ForwardDiff.jacobian(f, x), x)
7269
return reshape(out, n, n, n)
7370
end
7471

72+
function double_derivative(f, x)
73+
ForwardDiff.derivative(x -> ForwardDiff.derivative(f, x), x)
74+
end
75+
76+
function expected_item_information(ir::ItemResponse, θ::Float64)
77+
exp_resp = resp_vec(ir, θ)
78+
= double_derivative((θ -> log_resp_vec(ir, θ)), θ)
79+
-sum(exp_resp .* d²)
80+
end
81+
7582
# TODO: Unclear whether this should be implemented with ExpectationBasedItemCriterion
7683
# TODO: This is not implementing DRule but postposterior DRule
7784
function expected_item_information(ir::ItemResponse, θ::Vector{Float64})
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#=
2+
This file contains some specialised ways to calculate information.
3+
For some models analytical solutions are possible for information.
4+
Most are simple applications of the chain rule
5+
However, I haven't taken a systematic approach yet yet.
6+
So these are just from equations in the literature.
7+
8+
There aren't really any type guards on these so its up to the caller to make sure they are using the right ones.
9+
=#
10+
11+
function alt_expected_1d_item_information(ir::ItemResponse, θ)
12+
"""
13+
This is a special case of the expected_item_information function for
14+
* 1-dimensional ability
15+
* Dichotomous items
16+
* It should be valid for at least up to the 3PL model, probably others too
17+
18+
TODO: citation
19+
"""
20+
# irθ_prime = ForwardDiff.derivative(ir, θ)
21+
irθ_prime = ForwardDiff.derivative(x -> resp(ir, x), θ)
22+
irθ = resp(ir, θ)
23+
if irθ_prime == 0.0
24+
return 0.0
25+
else
26+
return (irθ_prime * irθ_prime) / (irθ * (1 - irθ))
27+
end
28+
end
29+
30+
function alt_expected_mirt_item_information(ir::ItemResponse, θ)
31+
"""
32+
This is a special case of the expected_item_information function for
33+
* Multidimensional
34+
* Dichotomous items
35+
* It should be valid for at least up to the 3PL model, probably others too
36+
37+
TODO: citation
38+
"""
39+
irθ_prime = ForwardDiff.gradient(x -> resp(ir, x), θ)
40+
= resp(ir, θ)
41+
= 1 -
42+
(irθ_prime * irθ_prime') / (pθ * qθ)
43+
end
44+
45+
function alt_expected_mirt_3pl_item_information(ir::ItemResponse, θ)
46+
"""
47+
This is a special case of the expected_item_information function for
48+
* Multidimensional
49+
* Dichotomous items
50+
* 3PL model only
51+
52+
Mulder J, van der Linden WJ.
53+
Multidimensional Adaptive Testing with Optimal Design Criteria for Item Selection.
54+
Psychometrika. 2009 Jun;74(2):273-296. doi: 10.1007/s11336-008-9097-5.
55+
Equation 4
56+
"""
57+
# XXX: Should avoid using item_params
58+
params = item_params(ir.item_bank.discriminations, ir.index)
59+
= resp(ir, θ)
60+
= 1 -
61+
a = params.discrimination
62+
c = params.guess
63+
common_factor = (qθ * (pθ - c)^2) / (pθ * (1 - c)^2)
64+
common_factor * (a * a')
65+
end

src/next_item_rules/mirt.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
abstract type MatrixScalarizer end
2+
3+
struct DeterminantScalarizer <: MatrixScalarizer end
4+
(::DeterminantScalarizer)(mat) = det(mat)
5+
6+
struct TraceScalarizer <: MatrixScalarizer end
7+
(::TraceScalarizer)(mat) = tr(mat)
8+
9+
abstract type StateCriteria end
10+
abstract type ItemCriteria end
11+
12+
struct AbilityCovarianceStateCriteria{
13+
DistEstT <: DistributionAbilityEstimator,
14+
IntegratorT <: AbilityIntegrator
15+
} <: StateCriteria
16+
dist_est::DistEstT
17+
integrator::IntegratorT
18+
skip_zero::Bool
19+
end
20+
21+
function AbilityCovarianceStateCriteria(bits...)
22+
skip_zero = false
23+
@requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...)
24+
return AbilityCovarianceStateCriteria(dist_est, integrator, skip_zero)
25+
end
26+
27+
# XXX: Should be at type level
28+
should_minimize(::AbilityCovarianceStateCriteria) = true
29+
30+
function (criteria::AbilityCovarianceStateCriteria)(
31+
tracked_responses::TrackedResponses,
32+
denom = normdenom(criteria.integrator,
33+
criteria.dist_est,
34+
tracked_responses)
35+
)
36+
if denom == 0.0 && criteria.skip_zero
37+
return Inf
38+
end
39+
covariance_matrix(
40+
criteria.integrator,
41+
criteria.dist_est,
42+
tracked_responses,
43+
denom
44+
)
45+
end
46+
47+
struct ScalarizedStateCriteron{
48+
StateCriteriaT <: StateCriteria,
49+
MatrixScalarizerT <: MatrixScalarizer
50+
} <: StateCriterion
51+
criteria::StateCriteriaT
52+
scalarizer::MatrixScalarizerT
53+
end
54+
55+
function (ssc::ScalarizedStateCriteron)(tracked_responses)
56+
res = ssc.criteria(tracked_responses) |> ssc.scalarizer
57+
if !should_minimize(ssc.criteria)
58+
res = -res
59+
end
60+
res
61+
end
62+
63+
struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: ItemCriteria
64+
ability_estimator::AbilityEstimatorT
65+
expected_item_information::F
66+
end
67+
68+
function InformationMatrixCriteria(ability_estimator)
69+
InformationMatrixCriteria(ability_estimator, expected_item_information)
70+
end
71+
72+
function init_thread(item_criterion::InformationMatrixCriteria,
73+
responses::TrackedResponses)
74+
# TODO: No need to do this one per thread. It just need to be done once per
75+
# θ update.
76+
# TODO: Update this to use track!(...) mechanism
77+
ability = maybe_tracked_ability_estimate(responses, item_criterion.ability_estimator)
78+
responses_information(responses.item_bank, responses.responses, ability)
79+
end
80+
81+
function (item_criterion::InformationMatrixCriteria)(acc_info::Matrix{Float64},
82+
tracked_responses::TrackedResponses,
83+
item_idx)
84+
# TODO: Add in information from the prior
85+
ability = maybe_tracked_ability_estimate(
86+
tracked_responses, item_criterion.ability_estimator)
87+
return acc_info .+
88+
item_criterion.expected_item_information(
89+
ItemResponse(tracked_responses.item_bank, item_idx), ability)
90+
end
91+
92+
should_minimize(::InformationMatrixCriteria) = false
93+
94+
struct ScalarizedItemCriteron{
95+
ItemCriteriaT <: ItemCriteria,
96+
MatrixScalarizerT <: MatrixScalarizer
97+
} <: ItemCriterion
98+
criteria::ItemCriteriaT
99+
scalarizer::MatrixScalarizerT
100+
end
101+
102+
function (ssc::ScalarizedItemCriteron)(tracked_responses, item_idx)
103+
res = ssc.criteria(
104+
init_thread(ssc.criteria, tracked_responses), tracked_responses, item_idx) |>
105+
ssc.scalarizer
106+
if !should_minimize(ssc.criteria)
107+
res = -res
108+
end
109+
res
110+
end
111+
112+
struct WeightedStateCriteria{InnerT <: StateCriteria} <: StateCriteria
113+
weights::Vector{Float64}
114+
criteria::InnerT
115+
end
116+
117+
function (wsc::WeightedStateCriteria)(tracked_responses, item_idx)
118+
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
119+
end
120+
121+
struct WeightedItemCriteria{InnerT <: ItemCriteria} <: ItemCriteria
122+
weights::Vector{Float64}
123+
criteria::InnerT
124+
end
125+
126+
function (wsc::WeightedItemCriteria)(tracked_responses, item_idx)
127+
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
128+
end

0 commit comments

Comments
 (0)