Skip to content

Commit 73b933d

Browse files
author
Frankie Robertson
committed
Add TestExt for stateful interface test
1 parent 75947f2 commit 73b933d

File tree

4 files changed

+133
-2
lines changed

4 files changed

+133
-2
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3131
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3232
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
3333

34+
[weakdeps]
35+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
36+
37+
[extensions]
38+
TestExt = "Test"
39+
3440
[compat]
3541
Accessors = "^0.1.12"
3642
Aqua = "0.8"

ext/TestExt.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
module TestExt
2+
3+
using Test
4+
using ComputerAdaptiveTesting: Stateful
5+
6+
export test_stateful_cat_1d_dich_ib
7+
8+
function test_stateful_cat_1d_dich_ib(
9+
cat::Stateful.StatefulCat,
10+
item_bank_length;
11+
supports_ranked_and_criteria = true,
12+
supports_rollback = true
13+
)
14+
if item_bank_length < 3
15+
error("Item bank length must be at least 3.")
16+
end
17+
@testset "response round trip" begin
18+
responses_before = Stateful.get_responses(cat)
19+
@test length(responses_before.indices) == 0
20+
@test length(responses_before.values) == 0
21+
22+
Stateful.add_response!(cat, 1, false)
23+
Stateful.add_response!(cat, 2, true)
24+
25+
responses_after_add = Stateful.get_responses(cat)
26+
@test length(responses_after_add.indices) == 2
27+
@test length(responses_after_add.values) == 2
28+
29+
Stateful.reset!(cat)
30+
responses_after_reset = Stateful.get_responses(cat)
31+
@test length(responses_after_reset.indices) == 0
32+
@test length(responses_after_reset.values) == 0
33+
end
34+
35+
# Test the next_item function
36+
@testset "basic next_item tests" begin
37+
Stateful.add_response!(cat, 1, false)
38+
Stateful.add_response!(cat, 2, true)
39+
40+
item = Stateful.next_item(cat)
41+
@test isa(item, Integer)
42+
@test item >= 1
43+
@test item >= 3
44+
@test item <= item_bank_length
45+
end
46+
47+
if supports_ranked_and_criteria
48+
@testset "basic ranked/criteria tests" begin
49+
items = Stateful.ranked_items(cat)
50+
@test length(items) == item_bank_length
51+
52+
criteria = Stateful.item_criteria(cat)
53+
@test length(criteria) == item_bank_length
54+
end
55+
end
56+
57+
if supports_rollback
58+
@testset "basic rollback tests" begin
59+
Stateful.reset!(cat)
60+
Stateful.add_response!(cat, 1, false)
61+
Stateful.add_response!(cat, 2, true)
62+
Stateful.rollback!(cat)
63+
responses_after_rollback = Stateful.get_responses(cat)
64+
@test length(responses_after_rollback.indices) == 1
65+
@test length(responses_after_rollback.values) == 1
66+
end
67+
end
68+
69+
Stateful.reset!(cat)
70+
71+
@testset "basic get_ability tests" begin
72+
Stateful.add_response!(cat, 1, false)
73+
Stateful.add_response!(cat, 2, true)
74+
ability = Stateful.get_ability(cat)
75+
@test isa(ability, Tuple)
76+
@test length(ability) == 2
77+
@test isa(ability[1], Float64)
78+
end
79+
80+
if supports_rollback
81+
@testset "rollback ability tests" begin
82+
Stateful.add_response!(cat, 1, false)
83+
ability1 = Stateful.get_ability(cat)
84+
Stateful.add_response!(cat, 2, true)
85+
ability2 = Stateful.get_ability(cat)
86+
Stateful.rollback!(cat)
87+
@test Stateful.get_ability(cat) == ability1
88+
Stateful.add_response!(cat, 2, true)
89+
@test Stateful.get_ability(cat) == ability2
90+
end
91+
end
92+
end
93+
94+
end

src/ComputerAdaptiveTesting.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ export NextItemRules, TerminationConditions
1010
export CatConfig, Sim, DecisionTree
1111
export Stateful, Comparison
1212

13+
# Extension modules
14+
public require_testext
15+
1316
# Vendored dependencies
1417
include("./vendor/PushVectors.jl")
1518

@@ -44,4 +47,12 @@ include("./Comparison.jl")
4447

4548
include("./precompiles.jl")
4649

50+
function require_testext()
51+
TestExt = Base.get_extension(@__MODULE__, :TestExt)
52+
if TestExt === nothing
53+
error("Failed to load extension module TestExt.")
54+
end
55+
return TestExt
4756
end
57+
58+
end

test/stateful.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition
77
using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule
88
using ComputerAdaptiveTesting: Stateful
9+
using ComputerAdaptiveTesting: require_testext
910
using ResumableFunctions
1011
using Test: @test, @testset
1112

@@ -26,7 +27,7 @@
2627
@testset "StatefulCatConfig basic usage" begin
2728
rules = CatRules(
2829
FixedItemsTerminationCondition(2),
29-
Dummy.DummyAbilityEstimator(0),
30+
Dummy.DummyAbilityEstimator(0.0),
3031
RandomNextItemRule()
3132
)
3233

@@ -54,7 +55,7 @@
5455
@testset "Stateful next item selection" begin
5556
rules = CatRules(
5657
FixedItemsTerminationCondition(2),
57-
Dummy.DummyAbilityEstimator(0),
58+
Dummy.DummyAbilityEstimator(0.0),
5859
RandomNextItemRule()
5960
)
6061
cat_config = Stateful.StatefulCatConfig(rules, item_bank)
@@ -69,4 +70,23 @@
6970
@test 1 <= second_item <= 4
7071
@test second_item != first_item # Should select different item
7172
end
73+
74+
@testset "Standard interface tests" begin
75+
rules = CatRules(
76+
FixedItemsTerminationCondition(2),
77+
Dummy.DummyAbilityEstimator(0.0),
78+
RandomNextItemRule()
79+
)
80+
81+
# Initialize config
82+
cat_config = Stateful.StatefulCatConfig(rules, item_bank)
83+
84+
# Run the standard interface tests
85+
TestExt = require_testext()
86+
TestExt.test_stateful_cat_1d_dich_ib(
87+
cat_config,
88+
4;
89+
supports_ranked_and_criteria = false,
90+
)
91+
end
7292
end

0 commit comments

Comments
 (0)