Skip to content

Commit 1998aa0

Browse files
author
Frankie Robertson
committed
Improve Comparison including IncreaseItemBankSizeExecutionStrategy
1 parent a2a2409 commit 1998aa0

File tree

2 files changed

+99
-35
lines changed

2 files changed

+99
-35
lines changed

src/Comparison.jl

Lines changed: 87 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module Comparison
44
# Should be kept in mind and kept distinct or code reuse
55

66
using StatsBase
7-
using FittedItemBanks: AbstractItemBank, ResponseType
7+
using FittedItemBanks: AbstractItemBank, ResponseType, subset
88
using ..Responses
99
using ..CatConfig: CatLoopConfig, CatRules
1010
using ..Aggregators: TrackedResponses, add_response!, Speculator, Aggregators, track!,
@@ -14,11 +14,11 @@ using Base: Iterators
1414

1515
using HypothesisTests
1616
using EffectSizes
17-
using DataFrames
17+
using DataFrames: DataFrame
1818
using ComputerAdaptiveTesting: Stateful
1919

2020
export run_random_comparison, run_comparison
21-
export CatComparisonExecutionStrategy#, IncreaseItemBankSizeExecutionStrategy
21+
export CatComparisonExecutionStrategy, IncreaseItemBankSizeExecutionStrategy
2222
#export FollowOneExecutionStrategy, RunIndependentlyExecutionStrategy
2323
#export DecisionTreeExecutionStrategy
2424
export ReplayResponsesExecutionStrategy
@@ -83,7 +83,8 @@ end
8383

8484
abstract type CatComparisonExecutionStrategy end
8585

86-
Base.@kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrategy}
86+
struct CatComparisonConfig{
87+
StrategyT <: CatComparisonExecutionStrategy, PhasesT <: NamedTuple}
8788
"""
8889
A named tuple with the (named) CatRules (or compatable) to be compared
8990
"""
@@ -99,13 +100,42 @@ Base.@kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrate
99100
measurements::Vector{}
100101
=#
101102
"""
102-
Which phases to run and/or call the callback on
103+
The phases to run, optionally paired with a callback
103104
"""
104-
phases::Set{Symbol} = Set((:before_next_item, :after_next_item))
105-
"""
106-
The callback which should take a named tuple with information at different phases
107-
"""
108-
callback::Any
105+
phases::PhasesT
106+
end
107+
108+
"""
109+
CatComparisonConfig(;
110+
rules::NamedTuple{Symbol, StatefulCat},
111+
strategy::CatComparisonExecutionStrategy,
112+
phases::Union{NamedTuple{Symbol, Callable}, Tuple{Symbol}},
113+
callback::Callable
114+
) -> CatComparisonConfig
115+
116+
CatComparisonConfig sets up a evaluation-oriented comparison between different CAT systems.
117+
118+
Specify the comparison by listing: CAT systems in `rules`, a `NamedTuple` which gives
119+
identifiers to implementations of the `StatefulCat` interface; the `strategy` to use,
120+
an implementation of `CatComparisonExecutionStrategy`; the `phases` to run listed as
121+
either as a `NamedTuple` with names of phases and corresponding callbacks or `nothing` a
122+
`Tuple` of phases to run; and a `callback` which will be used as a fallback in cases where
123+
no callback is provided.
124+
125+
The exact phases depend on the strategy used. See their individual documentation for more.
126+
"""
127+
function CatComparisonConfig(; rules, strategy, phases = nothing, callback = nothing)
128+
if callback === nothing
129+
callback = (info; kwargs...) -> nothing
130+
end
131+
if phases === nothing
132+
phases = (:before_next_item, :after_next_item)
133+
end
134+
# TODO: normalize phases into named tuple
135+
if !(phases isa NamedTuple)
136+
phases = NamedTuple((phase => callback for phase in phases))
137+
end
138+
CatComparisonConfig(rules, strategy, phases)
109139
end
110140

111141
# Comparison scenarios:
@@ -129,9 +159,11 @@ end
129159

130160
#phase_func=nothing;
131161
function measure_all(comparison, system, cat, phase; kwargs...)
132-
if !(phase in comparison.phases)
162+
@info "measure_all" phase comparison.phases
163+
if !(phase in keys(comparison.phases))
133164
return
134165
end
166+
callback = comparison.phases[phase]
135167
strategy = comparison.strategy
136168
#=measurement_results = []
137169
for measurement in comparison.measurements
@@ -145,7 +177,7 @@ function measure_all(comparison, system, cat, phase; kwargs...)
145177
#end
146178
push!(measurement_results, result)
147179
end=#
148-
comparison.callback((;
180+
callback((;
149181
phase,
150182
system,
151183
cat,
@@ -158,30 +190,56 @@ struct IncreaseItemBankSizeExecutionStrategy <: CatComparisonExecutionStrategy
158190
item_bank::AbstractItemBank
159191
sizes::AbstractVector{Int}
160192
starting_responses::Int
193+
shuffle::Bool
194+
time_limit::Float64
195+
196+
function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, args...)
197+
if any((size > length(item_bank) for size in sizes))
198+
error("IncreaseItemBankSizeExecutionStrategy: No subset size can be greater than the number of items available in the item bank")
199+
end
200+
new(item_bank, sizes, args...)
201+
end
161202
end
162203

163204
function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes)
164-
return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0)
205+
return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0, false, Inf)
165206
end
166207

167-
function run_comparison(strategy::IncreaseItemBankSizeExecutionStrategy, config)
208+
function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExecutionStrategy})
209+
strategy = comparison.strategy
210+
current_cats = collect(pairs(comparison.rules))
211+
next_current_cats = copy(current_cats)
212+
@info "sizes" strategy.sizes
168213
for size in strategy.sizes
169-
subsetted_item_bank = subset(strategy.item_bank, size)
170-
responses = TrackedResponses(
171-
BareResponses(ResponseType(strategy.item_bank)),
172-
subsetted_item_bank,
173-
config.ability_tracker
174-
)
175-
for _ in 1:(strategy.starting_responses)
176-
next_item = config.next_item(responses, subsetted_item_bank)
177-
add_response!(responses,
178-
Response(ResponseType(subsetted_item_bank), next_item, rand(Bool)))
214+
subsetted_item_bank = subset(strategy.item_bank, 1:size)
215+
empty!(next_current_cats)
216+
for (name, cat) in current_cats
217+
Stateful.set_item_bank!(cat, subsetted_item_bank)
218+
for _ in 1:(strategy.starting_responses)
219+
Stateful.next_item(cat)
220+
end
221+
measure_all(
222+
comparison,
223+
name,
224+
cat,
225+
:before_next_item
226+
)
227+
timed_next_item = @timed Stateful.next_item(cat)
228+
next_item = timed_next_item.value
229+
measure_all(
230+
comparison,
231+
name,
232+
cat,
233+
:after_next_item,
234+
next_item = next_item,
235+
timing = timed_next_item
236+
)
237+
@info "next_item" timed_next_item.time strategy.time_limit
238+
if timed_next_item.time < strategy.time_limit
239+
push!(next_current_cats, name => cat)
240+
end
179241
end
180-
measure_all(config, :before_next_item, before_next_item; responses = responses)
181-
timed_next_item = @timed config.next_item(responses, item_bank)
182-
next_item = timed_next_item.value
183-
measure_all(config, :after_next_item, after_next_item;
184-
responses = responses, next_item = next_item)
242+
current_cats, next_current_cats = next_current_cats, current_cats
185243
end
186244
end
187245

src/Stateful.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
struct StatefulCatConfig{ItemBankT <: AbstractItemBank} <: StatefulCat
6060
rules::CatRules
6161
tracked_responses::TrackedResponses
62-
item_bank::ItemBankT
62+
item_bank::Ref{ItemBankT}
6363
end
6464

6565
function StatefulCatConfig(rules, item_bank)
@@ -69,26 +69,27 @@ function StatefulCatConfig(rules, item_bank)
6969
item_bank,
7070
rules.ability_tracker
7171
)
72-
return StatefulCatConfig(rules, tracked_responses, item_bank)
72+
return StatefulCatConfig(rules, tracked_responses, Ref(item_bank))
7373
end
7474

7575
function next_item(config::StatefulCatConfig)
76-
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank)
76+
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank[])
7777
end
7878

7979
function ranked_items(config::StatefulCatConfig)
8080
return sortperm(compute_criteria(
81-
config.rules.next_item, config.tracked_responses, config.item_bank))
81+
config.rules.next_item, config.tracked_responses, config.item_bank[]))
8282
end
8383

8484
function item_criteria(config::StatefulCatConfig)
8585
return compute_criteria(
86-
config.rules.next_item, config.tracked_responses, config.item_bank)
86+
config.rules.next_item, config.tracked_responses, config.item_bank[])
8787
end
8888

8989
function add_response!(config::StatefulCatConfig, index, response)
9090
Aggregators.add_response!(
91-
config.tracked_responses, Response(ResponseType(config.item_bank), index, response))
91+
config.tracked_responses, Response(
92+
ResponseType(config.item_bank[]), index, response))
9293
end
9394

9495
function rollback!(config::StatefulCatConfig)
@@ -99,6 +100,11 @@ function reset!(config::StatefulCatConfig)
99100
empty!(config.tracked_responses)
100101
end
101102

103+
function set_item_bank!(config::StatefulCatConfig, item_bank)
104+
reset!(config)
105+
config.item_bank[] = item_bank
106+
end
107+
102108
function get_responses(config::StatefulCatConfig)
103109
return config.tracked_responses.responses
104110
end

0 commit comments

Comments
 (0)