Skip to content

Commit 234d623

Browse files
authored
Merge pull request #79 from JuliaPsychometricsBazaar/stateful-improvements
Stateful improvements
2 parents 27ae85f + 0bb3741 commit 234d623

File tree

9 files changed

+215
-31
lines changed

9 files changed

+215
-31
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3131
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3232
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
3333

34+
[weakdeps]
35+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
36+
37+
[extensions]
38+
TestExt = "Test"
39+
3440
[compat]
3541
Accessors = "^0.1.12"
3642
Aqua = "0.8"

ext/TestExt.jl

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
module TestExt
2+
3+
using Test
4+
using ComputerAdaptiveTesting: Stateful
5+
using FittedItemBanks: AbstractItemBank, ItemResponse, resp
6+
7+
export test_stateful_cat_1d_dich_ib, test_stateful_cat_item_bank_1d_dich_ib
8+
9+
function test_stateful_cat_1d_dich_ib(
10+
cat::Stateful.StatefulCat,
11+
item_bank_length;
12+
supports_ranked_and_criteria = true,
13+
supports_rollback = true
14+
)
15+
if item_bank_length < 3
16+
error("Item bank length must be at least 3.")
17+
end
18+
@testset "response round trip" begin
19+
responses_before = Stateful.get_responses(cat)
20+
@test length(responses_before.indices) == 0
21+
@test length(responses_before.values) == 0
22+
23+
Stateful.add_response!(cat, 1, false)
24+
Stateful.add_response!(cat, 2, true)
25+
26+
responses_after_add = Stateful.get_responses(cat)
27+
@test length(responses_after_add.indices) == 2
28+
@test length(responses_after_add.values) == 2
29+
30+
Stateful.reset!(cat)
31+
responses_after_reset = Stateful.get_responses(cat)
32+
@test length(responses_after_reset.indices) == 0
33+
@test length(responses_after_reset.values) == 0
34+
end
35+
36+
# Test the next_item function
37+
@testset "basic next_item tests" begin
38+
Stateful.add_response!(cat, 1, false)
39+
Stateful.add_response!(cat, 2, true)
40+
41+
item = Stateful.next_item(cat)
42+
@test isa(item, Integer)
43+
@test item >= 1
44+
@test item >= 3
45+
@test item <= item_bank_length
46+
end
47+
48+
if supports_ranked_and_criteria
49+
@testset "basic ranked/criteria tests" begin
50+
items = Stateful.ranked_items(cat)
51+
@test length(items) == item_bank_length
52+
53+
criteria = Stateful.item_criteria(cat)
54+
@test length(criteria) == item_bank_length
55+
end
56+
end
57+
58+
if supports_rollback
59+
@testset "basic rollback tests" begin
60+
Stateful.reset!(cat)
61+
Stateful.add_response!(cat, 1, false)
62+
Stateful.add_response!(cat, 2, true)
63+
Stateful.rollback!(cat)
64+
responses_after_rollback = Stateful.get_responses(cat)
65+
@test length(responses_after_rollback.indices) == 1
66+
@test length(responses_after_rollback.values) == 1
67+
end
68+
end
69+
70+
@testset "basic get_ability tests" begin
71+
Stateful.reset!(cat)
72+
Stateful.add_response!(cat, 1, false)
73+
Stateful.add_response!(cat, 2, true)
74+
ability = Stateful.get_ability(cat)
75+
@test isa(ability, Tuple)
76+
@test length(ability) == 2
77+
@test isa(ability[1], Float64)
78+
end
79+
80+
if supports_rollback
81+
@testset "rollback ability tests" begin
82+
Stateful.reset!(cat)
83+
Stateful.add_response!(cat, 1, false)
84+
ability1 = Stateful.get_ability(cat)
85+
Stateful.add_response!(cat, 2, true)
86+
ability2 = Stateful.get_ability(cat)
87+
Stateful.rollback!(cat)
88+
@test Stateful.get_ability(cat) == ability1
89+
Stateful.add_response!(cat, 2, true)
90+
@test Stateful.get_ability(cat) == ability2
91+
end
92+
end
93+
end
94+
95+
function test_stateful_cat_item_bank_1d_dich_ib(
96+
cat::Stateful.StatefulCat,
97+
item_bank::AbstractItemBank,
98+
points=[-.78, 0.0, .78],
99+
margin=0.05,
100+
)
101+
if length(item_bank) != Stateful.item_bank_size(cat)
102+
error("Item bank length does not match the cat's item bank size.")
103+
end
104+
for i in 1:length(item_bank)
105+
for point in points
106+
cat_prob = Stateful.item_response_function(cat, i, true, point)
107+
ib_prob = resp(ItemResponse(item_bank, i), true, point)
108+
@test cat_prob ib_prob rtol=margin
109+
end
110+
end
111+
end
112+
113+
end

src/ComputerAdaptiveTesting.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ export NextItemRules, TerminationConditions
1010
export CatConfig, Sim, DecisionTree
1111
export Stateful, Comparison
1212

13+
# Extension modules
14+
public require_testext
15+
1316
# Vendored dependencies
1417
include("./vendor/PushVectors.jl")
1518

@@ -44,4 +47,12 @@ include("./Comparison.jl")
4447

4548
include("./precompiles.jl")
4649

50+
function require_testext()
51+
TestExt = Base.get_extension(@__MODULE__, :TestExt)
52+
if TestExt === nothing
53+
error("Failed to load extension module TestExt.")
54+
end
55+
return TestExt
4756
end
57+
58+
end

src/Responses.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using FittedItemBanks: AbstractItemBank,
77
using AutoHashEquals: @auto_hash_equals
88

99
export Response, BareResponses, AbilityLikelihood, function_xs, function_ys
10+
export add_response!, pop_response!
1011

1112
concrete_response_type(::BooleanResponse) = Bool
1213
concrete_response_type(::MultinomialResponse) = Int
@@ -69,6 +70,28 @@ function Base.iterate(::BareResponses, gen_gen_state)
6970
return _iter_helper(gen, iterate(gen, gen_state))
7071
end
7172

73+
function Base.empty!(responses::BareResponses)
74+
Base.empty!(responses.indices)
75+
Base.empty!(responses.values)
76+
end
77+
78+
function add_response!(responses::BareResponses, response::Response)::BareResponses
79+
push!(responses.indices, response.index)
80+
push!(responses.values, response.value)
81+
responses
82+
end
83+
84+
function pop_response!(responses::BareResponses)::BareResponses
85+
pop!(responses.indices)
86+
pop!(responses.values)
87+
responses
88+
end
89+
90+
function Base.sizehint!(bare_responses::BareResponses, n)
91+
sizehint!(bare_responses.indices, n)
92+
sizehint!(bare_responses.values, n)
93+
end
94+
7295
struct AbilityLikelihood{ItemBankT <: AbstractItemBank, BareResponsesT <: BareResponses}
7396
item_bank::ItemBankT
7497
responses::BareResponsesT

src/Stateful.jl

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ module Stateful
77

88
using DocStringExtensions
99

10-
using FittedItemBanks: AbstractItemBank, ResponseType
10+
using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, resp
1111
using ..Aggregators: TrackedResponses, Aggregators
1212
using ..CatConfig: CatLoopConfig, CatRules
13-
using ..Responses: BareResponses, Response
13+
using ..Responses: BareResponses, Response, Responses
1414
using ..NextItemRules: compute_criteria, best_item
1515
using ..Sim: Sim, item_label
1616

@@ -124,6 +124,25 @@ but should attempt to interoperate with ComputerAdaptiveTesting.jl.
124124
"""
125125
function get_ability end
126126

127+
"""
128+
```julia
129+
$(FUNCTIONNAME)(config::StatefulCat)
130+
````
131+
132+
Return number of items in the current item bank.
133+
"""
134+
function item_bank_size end
135+
136+
"""
137+
```julia
138+
$(FUNCTIONNAME)(config::StatefulCat, index::IndexT, response::ResponseT, ability::AbilityT) -> Float
139+
````
140+
141+
Return the probability of a `response` to item at `index` for someone with
142+
a certain `ability` according to the IRT model backing the CAT.
143+
"""
144+
function item_response_function end
145+
127146
## Running the CAT
128147
function Sim.run_cat(cat_config::CatLoopConfig{RulesT},
129148
ib_labels = nothing) where {RulesT <: StatefulCat}
@@ -190,13 +209,13 @@ end
190209

191210
function add_response!(config::StatefulCatConfig, index, response)
192211
tracked_responses = config.tracked_responses[]
193-
Aggregators.add_response!(
212+
Responses.add_response!(
194213
tracked_responses, Response(
195214
ResponseType(tracked_responses.item_bank), index, response))
196215
end
197216

198217
function rollback!(config::StatefulCatConfig)
199-
pop_response!(config.tracked_responses[])
218+
Responses.pop_response!(config.tracked_responses[])
200219
end
201220

202221
function reset!(config::StatefulCatConfig)
@@ -220,6 +239,16 @@ function get_ability(config::StatefulCatConfig)
220239
return (config.rules.ability_estimator(config.tracked_responses[]), nothing)
221240
end
222241

242+
function item_bank_size(config::StatefulCatConfig)
243+
return length(config.tracked_responses[].item_bank)
244+
end
245+
246+
function item_response_function(config::StatefulCatConfig, index, response, ability)
247+
item_bank = config.tracked_responses[].item_bank
248+
item_response = ItemResponse(item_bank, index)
249+
return resp(item_response, response, ability)
250+
end
251+
223252
## TODO: Implementation for MaterializedDecisionTree
224253

225254
end

src/aggregators/Aggregators.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using FittedItemBanks: AbstractItemBank, ContinuousDomain,
1717
PointsItemBank, ResponseType, VectorContinuousDomain,
1818
domdims, item_params, resp, resp_vec, responses
1919
using ..Responses
20-
using ..Responses: concrete_response_type, function_xs, function_ys
20+
using ..Responses: concrete_response_type, function_xs, function_ys, Responses
2121
using ..ConfigBase
2222
using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome,
2323
find1_instance, find1_type,
@@ -37,8 +37,7 @@ import PsychometricsBazaarBase.IntegralCoeffs
3737
export AbilityEstimator, TrackedResponses
3838
export AbilityTracker, NullAbilityTracker, PointAbilityTracker, GriddedAbilityTracker
3939
export ClosedFormNormalAbilityTracker, track!
40-
export response_expectation,
41-
add_response!, pop_response!, expectation, distribution_estimator
40+
export response_expectation, expectation, distribution_estimator
4241
export PointAbilityEstimator, PriorAbilityEstimator, LikelihoodAbilityEstimator
4342
export ModeAbilityEstimator, MeanAbilityEstimator
4443
export Speculator, replace_speculation!, normdenom, maybe_tracked_ability_estimate

src/aggregators/ability_tracker.jl

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,19 @@
1-
function sizehint!(bare_responses::BareResponses, n)
2-
sizehint!(bare_responses.indices, n)
3-
sizehint!(bare_responses.values, n)
4-
end
5-
61
function track!(responses)
72
track!(responses, responses.ability_tracker)
83
end
94

10-
function add_response!(responses::BareResponses, response::Response)::BareResponses
11-
push!(responses.indices, response.index)
12-
push!(responses.values, response.value)
13-
responses
14-
end
15-
16-
function add_response!(tracked_responses::TrackedResponses, response::Response)
5+
function Responses.add_response!(tracked_responses::TrackedResponses, response::Response)
176
add_response!(tracked_responses.responses, response)
187
track!(tracked_responses)
198
end
209

21-
function pop_response!(responses::BareResponses)::BareResponses
22-
pop!(responses.indices)
23-
pop!(responses.values)
24-
responses
25-
end
26-
27-
function pop_response!(tracked_responses::TrackedResponses)::TrackedResponses
10+
function Responses.pop_response!(tracked_responses::TrackedResponses)::TrackedResponses
2811
pop_response!(tracked_responses.responses)
2912
tracked_responses
3013
end
3114

3215
function Base.empty!(tracked_responses::TrackedResponses)
33-
Base.empty!(tracked_responses.responses.indices)
34-
Base.empty!(tracked_responses.responses.values)
16+
Base.empty!(tracked_responses.responses)
3517
end
3618

3719
function response_expectation(ability_estimator::DistributionAbilityEstimator,

src/decision_tree/DecisionTree.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using ComputerAdaptiveTesting.ConfigBase: CatConfigBase
66
using ComputerAdaptiveTesting.PushVectors
77
using ComputerAdaptiveTesting.NextItemRules
88
using ComputerAdaptiveTesting.Aggregators
9-
using ComputerAdaptiveTesting.Responses: BareResponses, Response
9+
using ComputerAdaptiveTesting.Responses: BareResponses, Response, add_response!, pop_response!
1010
using FittedItemBanks: AbstractItemBank, BooleanResponse, ResponseType
1111

1212
# TODO: Remove ability tracking from here?

test/stateful.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition
77
using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule
88
using ComputerAdaptiveTesting: Stateful
9+
using ComputerAdaptiveTesting: require_testext
910
using ResumableFunctions
1011
using Test: @test, @testset
1112

@@ -26,7 +27,7 @@
2627
@testset "StatefulCatConfig basic usage" begin
2728
rules = CatRules(
2829
FixedItemsTerminationCondition(2),
29-
Dummy.DummyAbilityEstimator(0),
30+
Dummy.DummyAbilityEstimator(0.0),
3031
RandomNextItemRule()
3132
)
3233

@@ -54,7 +55,7 @@
5455
@testset "Stateful next item selection" begin
5556
rules = CatRules(
5657
FixedItemsTerminationCondition(2),
57-
Dummy.DummyAbilityEstimator(0),
58+
Dummy.DummyAbilityEstimator(0.0),
5859
RandomNextItemRule()
5960
)
6061
cat_config = Stateful.StatefulCatConfig(rules, item_bank)
@@ -69,4 +70,24 @@
6970
@test 1 <= second_item <= 4
7071
@test second_item != first_item # Should select different item
7172
end
73+
74+
@testset "Standard interface tests" begin
75+
rules = CatRules(
76+
FixedItemsTerminationCondition(2),
77+
Dummy.DummyAbilityEstimator(0.0),
78+
RandomNextItemRule()
79+
)
80+
81+
# Initialize config
82+
cat_config = Stateful.StatefulCatConfig(rules, item_bank)
83+
84+
# Run the standard interface tests
85+
TestExt = require_testext()
86+
TestExt.test_stateful_cat_1d_dich_ib(
87+
cat_config,
88+
4;
89+
supports_ranked_and_criteria = false,
90+
)
91+
TestExt.test_stateful_cat_item_bank_1d_dich_ib(cat_config, item_bank)
92+
end
7293
end

0 commit comments

Comments
 (0)