From a419d90cabb9d0468fb8637c077b38b0495a8e7c Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 20 Apr 2025 10:39:25 +0300 Subject: [PATCH 01/42] Update gitignore --- .gitignore | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 8ff79ba..5a2f3d5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,12 +2,9 @@ *.jl.cov *.jl.mem Manifest.toml -!/docs/Manifest.toml -!/test/Manifest.toml !/binder/Manifest.toml -/attic/ +attic/ /.vscode/ -/docs/attic/ /docs/build/ /docs/.CondaPkg/ /docs/LocalPreferences.toml From 920cd88829215f8e9efacecdba4bfb04cc01f899 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 20 Apr 2025 10:41:21 +0300 Subject: [PATCH 02/42] item_response_function => item_response_functions --- src/Stateful.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/Stateful.jl b/src/Stateful.jl index eb7d535..a9eaf74 100644 --- a/src/Stateful.jl +++ b/src/Stateful.jl @@ -7,7 +7,7 @@ module Stateful using DocStringExtensions -using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, resp +using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, resp_vec using ..Aggregators: TrackedResponses, Aggregators using ..CatConfig: CatLoopConfig, CatRules using ..Responses: BareResponses, Response, Responses @@ -135,13 +135,14 @@ function item_bank_size end """ ```julia -$(FUNCTIONNAME)(config::StatefulCat, index::IndexT, response::ResponseT, ability::AbilityT) -> Float +$(FUNCTIONNAME)(config::StatefulCat, index::IndexT, ability::AbilityT) -> AbstractVector{Float} ```` -Return the probability of a `response` to item at `index` for someone with -a certain `ability` according to the IRT model backing the CAT. +Return the vector of probability of different responses to item at +`index` for someone with a certain `ability` according to the IRT +model backing the CAT. """ -function item_response_function end +function item_response_functions end ## Running the CAT function Sim.run_cat(cat_config::CatLoopConfig{RulesT}, @@ -243,10 +244,10 @@ function item_bank_size(config::StatefulCatConfig) return length(config.tracked_responses[].item_bank) end -function item_response_function(config::StatefulCatConfig, index, response, ability) +function item_response_functions(config::StatefulCatConfig, index, ability) item_bank = config.tracked_responses[].item_bank item_response = ItemResponse(item_bank, index) - return resp(item_response, response, ability) + return resp_vec(item_response, ability) end ## TODO: Implementation for MaterializedDecisionTree From c3f4eb8d02b0c1a0e63b7e441d507ac89251adc3 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 30 May 2025 00:31:27 +0300 Subject: [PATCH 03/42] Add sequential next item strategies --- src/next_item_rules/NextItemRules.jl | 2 + src/next_item_rules/strategies/sequential.jl | 51 ++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 src/next_item_rules/strategies/sequential.jl diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index 43d9c4e..48565c9 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -39,6 +39,7 @@ export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread export NextItemRule, ItemStrategyNextItemRule export UrryItemCriterion, InformationItemCriterion export RandomNextItemRule +export PiecewiseNextItemRule, MemoryNextItemRule, FixedFirstItemNextItemRule export ExhaustiveSearch export catr_next_item_aliases export preallocate @@ -60,6 +61,7 @@ include("./prelude/preallocate.jl") # Selection strategies include("./strategies/random.jl") +include("./strategies/sequential.jl") include("./strategies/exhaustive.jl") # Combinators diff --git a/src/next_item_rules/strategies/sequential.jl b/src/next_item_rules/strategies/sequential.jl new file mode 100644 index 0000000..12676b2 --- /dev/null +++ b/src/next_item_rules/strategies/sequential.jl @@ -0,0 +1,51 @@ +""" +$(TYPEDEF) +$(TYPEDFIELDS) + +This is the most basic rule for choosing the next item in a CAT. It simply +picks a random item from the set of items that have not yet been +administered. +""" +@kwdef struct PiecewiseNextItemRule{BreaksT, RulesT} <: NextItemRule + # Tuple of Ints + breaks::BreaksT + # Types of NextItemRules + rules::RulesT +end + +#tuple_len(::NTuple{N, Any}) where {N} = Val{N}() + +function current_rule(rule::PiecewiseNextItemRule, responses::TrackedResponses) + for brk in 1:length(rule.breaks) + if length(responses) < rule.breaks[brk] + return rule.rules[brk] + end + end + return rule.rules[end] +end + +function best_item(rule::PiecewiseNextItemRule, responses::TrackedResponses, items) + return best_item(current_rule(rule, responses), responses, items) +end + +function compute_criteria(rule::PiecewiseNextItemRule, responses::TrackedResponses) + return compute_criteria(current_rule(rule, responses), responses) +end + +""" +""" +@kwdef struct MemoryNextItemRule{MemoryT} <: NextItemRule + item_idxs::MemoryT +end + +function best_item(rule::MemoryNextItemRule, responses::TrackedResponses, _items) + return rule.item_idxs[length(responses) + 1] + # XXX: A few problems with this: + # 1. Could run out of `item_idxs` + # 2. Could return an item not in `items` + # TODO: Add some basic error checking -- can only panic +end + +function FixedFirstItemNextItemRule(item_idx::Int, rule::NextItemRule) + PiecewiseNextItemRule((1,), (MemoryNextItemRule((item_idx,)), rule)) +end \ No newline at end of file From f7d3b744efa1f95d182bd1ce875f343819968b31 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 30 May 2025 00:39:00 +0300 Subject: [PATCH 04/42] Adjust test_stateful_cat_item_bank_1d_dich_ib to test whole vec --- ext/TestExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/TestExt.jl b/ext/TestExt.jl index c305dfc..1448888 100644 --- a/ext/TestExt.jl +++ b/ext/TestExt.jl @@ -2,7 +2,7 @@ module TestExt using Test using ComputerAdaptiveTesting: Stateful -using FittedItemBanks: AbstractItemBank, ItemResponse, resp +using FittedItemBanks: AbstractItemBank, ItemResponse, resp_vec export test_stateful_cat_1d_dich_ib, test_stateful_cat_item_bank_1d_dich_ib @@ -103,8 +103,8 @@ function test_stateful_cat_item_bank_1d_dich_ib( end for i in 1:length(item_bank) for point in points - cat_prob = Stateful.item_response_function(cat, i, true, point) - ib_prob = resp(ItemResponse(item_bank, i), true, point) + cat_prob = Stateful.item_response_functions(cat, i, point) + ib_prob = resp_vec(ItemResponse(item_bank, i), point) @test cat_prob ≈ ib_prob rtol=margin end end From e04a19f47f54d60c0950456c7144ee3a47748682 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 30 May 2025 00:39:59 +0300 Subject: [PATCH 05/42] Update compats --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 24ec8c9..12276e5 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ Distributions = "^0.25.88" DocStringExtensions = " ^0.9" EffectSizes = "^1.0.1" FillArrays = "0.13, 1.5.0" -FittedItemBanks = "^0.6.3" +FittedItemBanks = "^0.6.3, ^0.7.0" ForwardDiff = "1" HypothesisTests = "^0.10.12, ^0.11.0" Interpolations = "^0.14, ^0.15" @@ -62,7 +62,7 @@ PrecompileTools = "1.2.1" PsychometricsBazaarBase = "^0.8.1" Random = "^1.11" Reexport = "1" -ResumableFunctions = "^0.6" +ResumableFunctions = "^0.6, 1" Setfield = "^1" SparseArrays = "^1.11" StaticArrays = "1" From dac0c6bef20580ee5b64b1e1705e58382d5d9622 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 30 May 2025 00:53:45 +0300 Subject: [PATCH 06/42] Add watchdog, likelihood sampling, callback skipping to comparison --- src/{ => comparison}/Comparison.jl | 259 +++++++++++++++++++---------- src/comparison/watchdog.jl | 154 +++++++++++++++++ 2 files changed, 321 insertions(+), 92 deletions(-) rename src/{ => comparison}/Comparison.jl (69%) create mode 100644 src/comparison/watchdog.jl diff --git a/src/Comparison.jl b/src/comparison/Comparison.jl similarity index 69% rename from src/Comparison.jl rename to src/comparison/Comparison.jl index 40aa9e2..046bbcb 100644 --- a/src/Comparison.jl +++ b/src/comparison/Comparison.jl @@ -23,6 +23,8 @@ export CatComparisonExecutionStrategy, IncreaseItemBankSizeExecutionStrategy export ReplayResponsesExecutionStrategy export CatComparisonConfig +include("./watchdog.jl") + struct RandomCatComparison true_abilities::Array{Float64} rand_abilities::Array{Float64, 3} @@ -82,8 +84,7 @@ end abstract type CatComparisonExecutionStrategy end -struct CatComparisonConfig{ - StrategyT <: CatComparisonExecutionStrategy, PhasesT <: NamedTuple} +struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrategy, PhasesT <: NamedTuple} """ A named tuple with the (named) CatRules (or compatable) to be compared """ @@ -102,6 +103,18 @@ struct CatComparisonConfig{ The phases to run, optionally paired with a callback """ phases::PhasesT + """ + Where to sample for likelihood + """ + sample_points::Union{Vector{Float64}, Nothing} + """ + Skips + """ + skip_callback + """ + Watchdog timeout + """ + timeout::Float64 end """ @@ -109,6 +122,7 @@ end rules::NamedTuple{Symbol, StatefulCat}, strategy::CatComparisonExecutionStrategy, phases::Union{NamedTuple{Symbol, Callable}, Tuple{Symbol}}, + skips::Set{Tuple{Symbol, Symbol}}, callback::Callable ) -> CatComparisonConfig @@ -123,18 +137,24 @@ no callback is provided. The exact phases depend on the strategy used. See their individual documentation for more. """ -function CatComparisonConfig(; rules, strategy, phases = nothing, callback = nothing) +function CatComparisonConfig(; rules, strategy, phases = nothing, skip_callback = ((_, _, _) -> false), sample_points = nothing, callback = nothing, timeout = Inf) if callback === nothing callback = (info; kwargs...) -> nothing end if phases === nothing phases = (:before_next_item, :after_next_item) end - # TODO: normalize phases into named tuple if !(phases isa NamedTuple) phases = NamedTuple((phase => callback for phase in phases)) end - CatComparisonConfig(rules, strategy, phases) + CatComparisonConfig( + rules, + strategy, + phases, + sample_points, + skip_callback, + timeout + ) end # Comparison scenarios: @@ -158,7 +178,6 @@ end #phase_func=nothing; function measure_all(comparison, system, cat, phase; kwargs...) - @info "measure_all" phase system kwargs if !(phase in keys(comparison.phases)) return end @@ -273,7 +292,6 @@ function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExec num_items=size, system_name=name ) - @info "next_item" name timed_next_item.time strategy.time_limit if timed_next_item.time < strategy.time_limit push!(next_current_cats, name => cat) end @@ -300,108 +318,165 @@ end struct ReplayResponsesExecutionStrategy <: CatComparisonExecutionStrategy responses::BareResponses + time_limit::Float64 +end + +ReplayResponsesExecutionStrategy(responses) = ReplayResponsesExecutionStrategy(responses, Inf) + +function should_run(comparison, name, cat, phase) + return phase in keys(comparison.phases) && + !comparison.skip_callback(name, cat, phase) end # Which questions to ask: Specified # Which answer to use: From response memory function run_comparison(comparison::CatComparisonConfig{ReplayResponsesExecutionStrategy}) strategy = comparison.strategy - for (items_answered, response) in zip( - Iterators.countfrom(0), Iterators.flatten((strategy.responses, [nothing]))) - for (name, cat) in pairs(comparison.rules) - if :before_item_criteria in comparison.phases - timed_item_criteria = @timed Stateful.item_criteria(cat) - measure_all( - comparison, - name, - cat, - :before_item_criteria, - items_answered = items_answered, - item_criteria = timed_item_criteria.value, - timing = timed_item_criteria - ) - end - if :before_ranked_items in comparison.phases - timed_ranked_items = @timed Stateful.ranked_items(cat) - measure_all( - comparison, - name, - cat, - :before_ranked_items, - items_answered = items_answered, - ranked_items = timed_ranked_items.value, - timing = timed_ranked_items - ) - end - if :before_ability in comparison.phases - timed_get_ability = @timed Stateful.get_ability(cat) - measure_all( - comparison, - name, - cat, - :before_ability, - items_answered = items_answered, - ability = timed_get_ability.value, - timing = timed_get_ability - ) - end - measure_all( - comparison, - name, - cat, - :before_next_item, - items_answered = items_answered - ) - timed_next_item = @timed Stateful.next_item(cat) - next_item = timed_next_item.value - measure_all( - comparison, - name, - cat, - :after_next_item, - next_item = next_item, - timing = timed_next_item, - items_answered = items_answered - ) - if :after_item_criteria in comparison.phases - # TOOD: Combine with next_item if possible and requested? - timed_item_criteria = @timed Stateful.item_criteria(cat) - measure_all( - comparison, - name, - cat, - :after_item_criteria, - items_answered = items_answered, - item_criteria = timed_item_criteria.value, - timing = timed_item_criteria - ) + current_cats = Dict(pairs(comparison.rules)) + function check_time(name, timer) + if timer.time >= strategy.time_limit + if name in keys(current_cats) + @info "Time limit exceeded" name timer.time + delete!(current_cats, name) end - if :after_ranked_items in comparison.phases - timed_ranked_items = @timed Stateful.ranked_items(cat) + end + end + watchdog = WatchdogTask(comparison.timeout) + start!(watchdog) do + for (items_answered, response) in zip( + Iterators.countfrom(0), Iterators.flatten((strategy.responses, [nothing]))) + for (name, cat) in pairs(current_cats) + println("") + println("Starting $name for $items_answered") + flush(stdout) + if should_run(comparison, name, cat, :before_item_criteria) + reset!(watchdog, "$name item_criteria") + timed_item_criteria = @timed Stateful.item_criteria(cat) + check_time(name, timed_item_criteria) + measure_all( + comparison, + name, + cat, + :before_item_criteria, + items_answered = items_answered, + item_criteria = timed_item_criteria.value, + timing = timed_item_criteria + ) + end + if should_run(comparison, name, cat, :before_ranked_items) + reset!(watchdog, "$name ranked_items") + timed_ranked_items = @timed Stateful.ranked_items(cat) + check_time(name, timed_ranked_items) + measure_all( + comparison, + name, + cat, + :before_ranked_items, + items_answered = items_answered, + ranked_items = timed_ranked_items.value, + timing = timed_ranked_items + ) + end + if should_run(comparison, name, cat, :before_ability) + reset!(watchdog, "$name get_ability") + timed_get_ability = @timed Stateful.get_ability(cat) + check_time(name, timed_get_ability) + measure_all( + comparison, + name, + cat, + :before_ability, + items_answered = items_answered, + ability = timed_get_ability.value, + timing = timed_get_ability + ) + end measure_all( comparison, name, cat, - :after_ranked_items, - items_answered = items_answered, - ranked_items = timed_ranked_items.value, - timing = timed_ranked_items + :before_next_item, + items_answered = items_answered ) - end - if :after_ability in comparison.phases - timed_get_ability = @timed Stateful.get_ability(cat) + reset!(watchdog, "$name next_item") + timed_next_item = @timed Stateful.next_item(cat) + check_time(name, timed_next_item) + next_item = timed_next_item.value measure_all( comparison, name, cat, - :after_ability, - items_answered = items_answered, - ability = timed_get_ability.value, - timing = timed_get_ability + :after_next_item, + next_item = next_item, + timing = timed_next_item, + items_answered = items_answered ) - end - if response !== nothing - Stateful.add_response!(cat, response.index, response.value) + if should_run(comparison, name, cat, :after_item_criteria) + # TOOD: Combine with next_item if possible and requested? + reset!(watchdog, "$name item_criteria") + timed_item_criteria = @timed Stateful.item_criteria(cat) + check_time(name, timed_item_criteria) + if timed_item_criteria.value !== nothing + measure_all( + comparison, + name, + cat, + :after_item_criteria, + items_answered = items_answered, + item_criteria = timed_item_criteria.value, + timing = timed_item_criteria + ) + end + end + if should_run(comparison, name, cat, :after_ranked_items) + reset!(watchdog, "$name ranked_items") + timed_ranked_items = @timed Stateful.ranked_items(cat) + check_time(name, timed_ranked_items) + if timed_ranked_items.value !== nothing + measure_all( + comparison, + name, + cat, + :after_ranked_items, + items_answered = items_answered, + ranked_items = timed_ranked_items.value, + timing = timed_ranked_items + ) + end + end + if should_run(comparison, name, cat, :after_likelihood) + reset!(watchdog, "$name likelihood") + timed_likelihood = @timed Stateful.likelihood.(Ref(cat), comparison.sample_points) + check_time(name, timed_likelihood) + measure_all( + comparison, + name, + cat, + :after_likelihood, + items_answered = items_answered, + sample_points = comparison.sample_points, + likelihood = timed_likelihood.value, + timing = timed_likelihood + ) + + end + if should_run(comparison, name, cat, :after_ability) + reset!(watchdog, "$name get_ability") + timed_get_ability = @timed Stateful.get_ability(cat) + check_time(name, timed_get_ability) + measure_all( + comparison, + name, + cat, + :after_ability, + items_answered = items_answered, + ability = timed_get_ability.value, + timing = timed_get_ability + ) + end + if response !== nothing + Stateful.add_response!(cat, response.index, response.value) + end end end end diff --git a/src/comparison/watchdog.jl b/src/comparison/watchdog.jl new file mode 100644 index 0000000..7a4fda8 --- /dev/null +++ b/src/comparison/watchdog.jl @@ -0,0 +1,154 @@ +using Base.Threads: nthreads + + +abstract type AbstractWatchdogTask end + +mutable struct WatchdogTask <: AbstractWatchdogTask + timeout::Float64 + channel::Channel + task::Union{Task, Nothing} +end + +function run_watchdog(timeout, channel, worker_task) + #Core.println("Starting watchdog") + #Base.flush(stdout) + reset_timestamp = time() + deadline = reset_timestamp + timeout + msg = nothing + active = false + die = false + #Core.println("X") + #Base.flush(stdout) + l = ReentrantLock() + #Core.println("Y") + #Base.flush(stdout) + activation = Threads.Condition(l) + #Core.println("Blam") + #Base.flush(stdout) + @async begin + #Core.println("Subloop") + while true + cmd = take!(channel) + #Core.println("Take") + if haskey(cmd, :kill) + die = true + lock(l) do + notify(activation) + end + break + end + lock(l) do + if haskey(cmd, :reset_timestamp) + reset_timestamp = cmd[:reset_timestamp] + deadline = reset_timestamp + timeout + end + if haskey(cmd, :msg) + msg = cmd[:msg] + end + if haskey(cmd, :active) + active = cmd[:active] + if active + #Core.println("Notify") + notify(activation) + end + end + end + end + end + loop = true + while loop + #Core.println("Aquiring lock") + loop = lock(l) do + while !active && !die + #Core.println("Waiting for activation") + wait(activation) + end + if die + return false + end + if active + unlock(l) + try + delay = deadline - time() + #Core.println("Sleeping for $delay") + sleep(max(delay, 0.0)) + finally + lock(l) + end + end + if die + return false + end + overrun = time() - deadline + if overrun > 0 && active + msg = "WATCHDOG TIMEOUT: $msg timed after after $(timeout)s (overran $(overrun)s)" + unlock(l) + put!(channel, (; kill=true)) + Core.println("") + Core.println(msg) + Core.println("") + Base.flush(Core.stdout) + sleep(0.1) + schedule(worker_task, InterruptException(), error=true) + # Wait a proper amount of time here since otherwise we will probably not get a stacktrace + sleep(5.0) + if istaskdone(worker_task) + return false + end + ccall(:uv_kill, Cint, (Cint, Cint), getpid(), Base.SIGTERM) + sleep(1.0) + ccall(:uv_kill, Cint, (Cint, Cint), getpid(), Base.SIGKILL) + sleep(1.0) + exit(1) # This is done last since it doesn't always take down the parent + end + return true + end + end +end + +function WatchdogTask(timeout::Float64) + if timeout !== Inf + channel = Channel{Any}(Inf) + WatchdogTask(timeout, channel, nothing) + else + NullWatchdog() + end + #WatchdogTask(task, timeout, channel, nothing) +end + +function start!(f, watchdog::WatchdogTask) + if nthreads(:interactive) < 1 || nthreads(:default) < 1 + error("WatchdogTask: Need an interactive and default thread") + end + worker_task = Threads.@spawn :default f() + watchdog.task = Threads.@spawn :interactive run_watchdog(watchdog.timeout, watchdog.channel, worker_task) + wait(worker_task) + put!(watchdog.channel, (; kill=true)) + wait(watchdog.task) +end + +function reset!(watchdog::WatchdogTask, msg=nothing) + if istaskdone(watchdog.task) + wait(watchdog.task) + end + payload = (; + active=true, + reset_timestamp=time(), + ) + if msg !== nothing + payload = (; payload..., msg=msg) + end + #@info "Put" payload + put!(watchdog.channel, payload) +end + +function deactivate!(watchdog::WatchdogTask) + if istaskdone(watchdog.task) + wait(watchdog.task) + end + put(watchdog.channel, (; active=false)) +end + +struct NullWatchdog <: AbstractWatchdogTask end +function reset!(::NullWatchdog, msg=nothing) end +function deactivate!(::NullWatchdog) end \ No newline at end of file From 4526bba5ec6ad64a03a5e0322abdc3fb58219eff Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 30 May 2025 00:59:03 +0300 Subject: [PATCH 07/42] Add compat layer --- src/Compat/CatR.jl | 111 ++++++++++++++++++++++++++ src/Compat/Compat.jl | 6 ++ src/Compat/MirtCAT.jl | 141 +++++++++++++++++++++++++++++++++ src/ComputerAdaptiveTesting.jl | 5 +- test/compat.jl | 78 ++++++++++++++++++ test/runtests.jl | 1 + 6 files changed, 340 insertions(+), 2 deletions(-) create mode 100644 src/Compat/CatR.jl create mode 100644 src/Compat/Compat.jl create mode 100644 src/Compat/MirtCAT.jl create mode 100644 test/compat.jl diff --git a/src/Compat/CatR.jl b/src/Compat/CatR.jl new file mode 100644 index 0000000..1314627 --- /dev/null +++ b/src/Compat/CatR.jl @@ -0,0 +1,111 @@ +module CatR + +using ComputerAdaptiveTesting.Aggregators: AbilityIntegrator, + LikelihoodAbilityEstimator, + DistributionAbilityEstimator, + ModeAbilityEstimator, + MeanAbilityEstimator, + PriorAbilityEstimator +using ComputerAdaptiveTesting.TerminationConditions: RunForeverTerminationCondition +using ComputerAdaptiveTesting.CatConfig: CatRules +using ComputerAdaptiveTesting.NextItemRules +using PsychometricsBazaarBase: Integrators, Optimizers + +public next_item_aliases, ability_estimator_aliases, assemble_rules + +function _next_item_aliases() + res = Dict{String, Any}() + for (nick, mk_item_criterion) in ( + "MFI" => InformationItemCriterion, + "bOpt" => UrryItemCriterion, + ) + res[nick] = (bits...; kwargs...) -> ItemStrategyNextItemRule( + ExhaustiveSearch(), + mk_item_criterion(bits...)) + end + res["MEPV"] = (bits...; posterior_ability_estimator, kwargs...) -> ItemStrategyNextItemRule( + ExhaustiveSearch(), + ExpectationBasedItemCriterion(bits..., + AbilityVarianceStateCriterion(posterior_ability_estimator, AbilityIntegrator(bits...)))) + res["MEI"] = (bits...; kwargs...) -> ItemStrategyNextItemRule( + ExhaustiveSearch(), + ExpectationBasedItemCriterion(bits..., + InformationItemCriterion(bits...))) + #"MLWI", #"MPWI", + return res + #"thOpt", + #"progressive", + #"proportional", + #"KL", + #"KLP", + #"GDI", + #"GDIP", + #"random" +end + +""" +This mapping provides next item rules through the same names that they are +available through in the `catR` R package. TODO compability with `mirtcat` +""" +const next_item_aliases = _next_item_aliases() + +function _ability_estimator_aliases() + res = Dict{String, Any}() + res["BM"] = (; optimizer, kwargs...) -> ModeAbilityEstimator(PriorAbilityEstimator(), optimizer) + res["ML"] = (; optimizer, kwargs...) -> ModeAbilityEstimator(LikelihoodAbilityEstimator(), optimizer) + res["EAP"] = (; integrator, kwargs...) -> MeanAbilityEstimator(PriorAbilityEstimator(), integrator) + #res["WL"] + #res["ROB"] + return res +end + +const ability_estimator_aliases = _ability_estimator_aliases() + +#= + for (resp_exp, resp_exp_nick) in resp_exp_nick_pairs + next_item_rule = NextItemRule( + ExpectationBasedItemCriterion(resp_exp, AbilityVarianceStateCriterion(numtools.integrator, distribution_estimator(abil_est))) + ) + next_item_rule = preallocate(next_item_rule) + est_next_item_rule_pairs[Symbol("$(abil_est_str)_mepv_$(resp_exp_nick)")] = (abil_est, next_item_rule) + next_item_rule = NextItemRule( + ExpectationBasedItemCriterion(resp_exp, InformationItemCriterion(abil_est)) + ) + next_item_rule = preallocate(next_item_rule) + est_next_item_rule_pairs[Symbol("$(abil_est_str)_mei_$(resp_exp_nick)")] = (abil_est, next_item_rule) + end + est_next_item_rule_pairs[Symbol("$(abil_est_str)_mi")] = (abil_est, InformationItemCriterion(abil_est)) +=# + + +function setup_integrator(lo=-4.0, hi=4.0, pts=33) + Integrators.MidpointIntegrator(range(lo, hi, pts)) +end + +function setup_optimizer(lo=-4.0, hi=4.0) + Optimizers.NativeOneDimOptimOptimizer(; lo, hi) +end + +function assemble_rules(; + criterion, + method, + start_item = 1 + #prior_dist="norm", + #prior_par=@SVector[0.0, 1.0], + #info_type="observed" +) + integrator = setup_integrator() + optimizer = setup_optimizer() + ability_estimator = ability_estimator_aliases[method](; integrator, optimizer) + posterior_ability_estimator = PriorAbilityEstimator() + raw_next_item = next_item_aliases[criterion](ability_estimator, integrator, optimizer; posterior_ability_estimator=posterior_ability_estimator) + next_item = FixedFirstItemNextItemRule(start_item, raw_next_item) + CatRules(; + next_item, + termination_condition = RunForeverTerminationCondition(), + ability_estimator, + #ability_tracker::AbilityTrackerT = NullAbilityTracker() + ) +end + +end \ No newline at end of file diff --git a/src/Compat/Compat.jl b/src/Compat/Compat.jl new file mode 100644 index 0000000..a8a29ab --- /dev/null +++ b/src/Compat/Compat.jl @@ -0,0 +1,6 @@ +module Compat + +include("./CatR.jl") +include("./MirtCAT.jl") + +end \ No newline at end of file diff --git a/src/Compat/MirtCAT.jl b/src/Compat/MirtCAT.jl new file mode 100644 index 0000000..7d5a212 --- /dev/null +++ b/src/Compat/MirtCAT.jl @@ -0,0 +1,141 @@ +module MirtCAT + +using ComputerAdaptiveTesting.Aggregators: SafeLikelihoodAbilityEstimator, + LikelihoodAbilityEstimator, + DistributionAbilityEstimator, + ModeAbilityEstimator, + MeanAbilityEstimator, + PriorAbilityEstimator, + AbilityEstimator, + distribution_estimator +using ComputerAdaptiveTesting.TerminationConditions: RunForeverTerminationCondition +using ComputerAdaptiveTesting.CatConfig: CatRules +using ComputerAdaptiveTesting.NextItemRules +using PsychometricsBazaarBase: Integrators, Optimizers + +public next_item_aliases, ability_estimator_aliases, assemble_rules + +function _next_item_helper(item_criterion_callback) + function _helper(ability_estimator, posterior_ability_estimator, integrator, optimizer) + bits = [ + ability_estimator, + integrator, + optimizer, + ] + item_criterion = item_criterion_callback(; bits, ability_estimator, posterior_ability_estimator, integrator, optimizer) + return ItemStrategyNextItemRule(ExhaustiveSearch(), item_criterion) + end + return _helper +end + +const next_item_aliases = Dict( + # "MI' for the maximum information + "MI" => _next_item_helper((; bits, ability_estimator, rest...) -> InformationItemCriterion(ability_estimator)), + # 'MEPV' for minimum expected posterior variance + "MEPV" => _next_item_helper((; bits, ability_estimator, posterior_ability_estimator, integrator, rest...) -> ExpectationBasedItemCriterion( + ability_estimator, + AbilityVarianceStateCriterion(posterior_ability_estimator, integrator))), + "MEI" => _next_item_helper((; bits, ability_estimator, rest...) -> ExpectationBasedItemCriterion( + ability_estimator, + PointItemCategoryCriterion(EmpiricalInformationPointwiseItemCategoryCriterion(), ability_estimator) + )), + "MLWI" => _next_item_helper((; bits, ability_estimator, integrator, rest...) -> LikelihoodWeightedItemCriterion( + TotalItemInformation(RawEmpiricalInformationPointwiseItemCategoryCriterion()), + distribution_estimator(ability_estimator), + integrator + )), + "MPWI" => _next_item_helper((; bits, ability_estimator, posterior_ability_estimator, integrator, rest...) -> LikelihoodWeightedItemCriterion( + TotalItemInformation(RawEmpiricalInformationPointwiseItemCategoryCriterion()), + distribution_estimator(posterior_ability_estimator), + integrator + )), + "Drule" => _next_item_helper((; bits, ability_estimator, rest...) -> DRuleItemCriteron(ability_estimator)), + "Trule" => _next_item_helper((; bits, ability_estimator, rest...) -> TRuleItemCriteron(ability_estimator)) +) + +# 'IKLP' as well as 'IKL' for the integration based Kullback-Leibler criteria with and without the prior density weight, +# respectively, and their root-n items administered weighted counter-parts, 'IKLn' and 'IKLPn'. +#= +Possible inputs for multidimensional adaptive tests include: 'Drule' for the +maximum determinant of the information matrix, 'Trule' for the maximum +(potentially weighted) trace of the information matrix, 'Arule' for the minimum (potentially weighted) trace of the asymptotic covariance matrix, 'Erule' +for the minimum value of the information matrix, and 'Wrule' for the weighted +information criteria. For each of these rules, the posterior weight for the latent trait scores can also be included with the 'DPrule', 'TPrule', 'APrule', +'EPrule', and 'WPrule', respectively. +Applicable to both unidimensional and multidimensional tests are the 'KL' and +'KLn' for point-wise Kullback-Leibler divergence and point-wise KullbackLeibler with a decreasing delta value (delta*sqrt(n), where n is the number +of items previous answered), respectively. The delta criteria is defined in the +design object +Non-adaptive methods applicable even when no mo object is passed are: 'random' +to randomly select items, and 'seq' for selecting items sequentially +=# + +const ability_estimator_aliases = Dict( + "MAP" => (; optimizer, kwargs...) -> ModeAbilityEstimator(PriorAbilityEstimator(), optimizer), + "ML" => (; optimizer, kwargs...) -> ModeAbilityEstimator(SafeLikelihoodAbilityEstimator(), optimizer), + "EAP" => (; integrator, kwargs...) -> MeanAbilityEstimator(PriorAbilityEstimator(), integrator), +# "WLE" for weighted likelihood estimation +# "EAPsum" for the expected a-posteriori for each sum score +) + +#= +• "plausible" for a single plausible value imputation for each case. This is +equivalent to setting plausible.draws = 1 +• "classify" for the posteriori classification probabilities (only applicable +when the input model was of class MixtureClass) +=# + +function mirtcat_quadpts(nfact) + if nfact == 1 + return 121 + elseif nfact == 2 + return 61 + elseif nfact == 3 + return 31 + elseif nfact == 4 + return 19 + elseif nfact == 5 + return 11 + else + return 5 + end +end + +function setup_integrator(lo=-6.0, hi=6.0, pts=mirtcat_quadpts(1)) + Integrators.even_grid(lo, hi, pts) +end + +function setup_optimizer(lo=-6.0, hi=6.0) + # TODO: Is this correct? + # mirtcat uses the `nlm` function from the `stats` package + # Source: https://github.com/philchalmers/mirt/blob/46b5db3a0120d87b8f1b034e6111fc5fb352a698/R/fscores.internal.R#L957C25-L957C28 + # It looks like no gradient is passed, so the numerical gradient will be used + # Source: https://github.com/philchalmers/mirt/blob/46b5db3a0120d87b8f1b034e6111fc5fb352a698/R/fscores.internal.R#L623 + # This is what we get by default so do this + # Main difference is probably in the line search + # https://stats.stackexchange.com/questions/272880/algorithm-used-in-nlm-function-in-r + # So just use Newton() with defaults for now + # Except then we can't have box constraints so I suppose IPNewton + Optimizers.OneDimOptimOptimizer(lo, hi, Optimizers.IPNewton()) +end + +function assemble_rules(; + criteria = "MI", + method = "MAP", + start_item = 1 +) + integrator = setup_integrator() + optimizer = setup_optimizer() + ability_estimator = ability_estimator_aliases[method](; integrator, optimizer) + posterior_ability_estimator = PriorAbilityEstimator() + @info "assemble rules" criteria + raw_next_item = next_item_aliases[criteria](ability_estimator, posterior_ability_estimator, integrator, optimizer) + next_item = FixedFirstItemNextItemRule(start_item, raw_next_item) + CatRules(; + next_item, + ability_estimator, + termination_condition = RunForeverTerminationCondition(), + ) +end + +end \ No newline at end of file diff --git a/src/ComputerAdaptiveTesting.jl b/src/ComputerAdaptiveTesting.jl index 328a71c..38137bf 100644 --- a/src/ComputerAdaptiveTesting.jl +++ b/src/ComputerAdaptiveTesting.jl @@ -37,9 +37,10 @@ include("./CatConfig.jl") include("./Sim.jl") include("./decision_tree/DecisionTree.jl") -# Stateful layer and comparison +# Stateful layer, compat, and comparison include("./Stateful.jl") -include("./Comparison.jl") +include("./Compat/Compat.jl") +include("./Comparison/Comparison.jl") @reexport using .CatConfig: CatLoopConfig, CatRules @reexport using .Sim: run_cat diff --git a/test/compat.jl b/test/compat.jl new file mode 100644 index 0000000..173d828 --- /dev/null +++ b/test/compat.jl @@ -0,0 +1,78 @@ +@testset "Compat" begin + using FittedItemBanks.DummyData: dummy_full + using FittedItemBanks: OneDimContinuousDomain, SimpleItemBankSpec, StdModel3PL, + BooleanResponse + using ComputerAdaptiveTesting.Aggregators: TrackedResponses, NullAbilityTracker + using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition + using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule + using ComputerAdaptiveTesting.Responses: BareResponses, ResponseType + using ComputerAdaptiveTesting: Stateful + using ComputerAdaptiveTesting: require_testext + using ComputerAdaptiveTesting.ItemBanks: LogItemBank + using ComputerAdaptiveTesting.NextItemRules: best_item + using ComputerAdaptiveTesting: Compat + using ResumableFunctions + using Test: @test, @testset + + #include("./dummy.jl") + #using .Dummy + using Random + + rng = Random.default_rng(42) + (item_bank, abilities, true_responses) = dummy_full( + Random.default_rng(42), + SimpleItemBankSpec(StdModel3PL(), OneDimContinuousDomain(), BooleanResponse()); + num_questions = 4, + num_testees = 1 + ) + half_responses = BareResponses( + ResponseType(item_bank), + [1, 2], + Vector{Bool}(true_responses[1:2, 1]) + ) + + @testset "CatJL" begin + log_item_bank = LogItemBank(item_bank) + tracked_responses = TrackedResponses(half_responses, log_item_bank, NullAbilityTracker()) + for method in ("EAP", "MAP", "ML") + @testset "Ability estimation $method" begin + rules = Compat.MirtCAT.assemble_rules(; + criteria="MI", + method + ) + @test -6.0 <= rules.ability_estimator(tracked_responses) <= 6.0 + end + end + for criteria in ("MI", "MEPV") + @testset "Next item $criteria" begin + rules = Compat.MirtCAT.assemble_rules(; + criteria, + method="EAP" + ) + @test best_item(rules.next_item, tracked_responses) in 3:4 + end + end + end + + @testset "CatR" begin + tracked_responses = TrackedResponses(half_responses, item_bank, NullAbilityTracker()) + for method in ("EAP", "BM", "ML") + @testset "Ability estimation $method" begin + rules = Compat.CatR.assemble_rules(; + criterion="MFI", + method + ) + @test -6.0 <= rules.ability_estimator(tracked_responses) <= 6.0 + end + end + for criterion in ("MFI", "bOpt", "MEPV", "MEI") + @testset "Next item $criterion" begin + rules = Compat.CatR.assemble_rules(; + criterion, + method="EAP" + ) + @test best_item(rules.next_item, tracked_responses) in 3:4 + end + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0610f2a..90f4823 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,4 +32,5 @@ using .Dummy include("./smoke.jl") include("./dt.jl") include("./stateful.jl") + include("./compat.jl") end From b8046ec0376d81b011a0868c174b463adebcd9f9 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 30 May 2025 01:00:06 +0300 Subject: [PATCH 08/42] Add likelihood to Stateful --- src/Stateful.jl | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/Stateful.jl b/src/Stateful.jl index a9eaf74..38450a0 100644 --- a/src/Stateful.jl +++ b/src/Stateful.jl @@ -8,7 +8,7 @@ module Stateful using DocStringExtensions using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, resp_vec -using ..Aggregators: TrackedResponses, Aggregators +using ..Aggregators: TrackedResponses, Aggregators, pdf, distribution_estimator using ..CatConfig: CatLoopConfig, CatRules using ..Responses: BareResponses, Response, Responses using ..NextItemRules: compute_criteria, best_item @@ -45,6 +45,7 @@ $(FUNCTIONNAME)(config::StatefulCat) -> AbstractVector{IndexT} Return a vector of indices of the sorted from best to worst item according to the CAT. """ function ranked_items end +function ranked_items(::StatefulCat) nothing end """ ```julia @@ -56,6 +57,7 @@ Returns a vector of criteria values for each item in the item bank. The criteria can vary, but should attempt to interoperate with ComputerAdaptiveTesting.jl. """ function item_criteria end +function item_criteria(::StatefulCat) nothing end """ ```julia @@ -124,6 +126,15 @@ but should attempt to interoperate with ComputerAdaptiveTesting.jl. """ function get_ability end +""" +```julia +$(FUNCTIONNAME)(config::StatefulCat, ability::AbilityT) -> Float64 +``` + +TODO +""" +function likelihood end + """ ```julia $(FUNCTIONNAME)(config::StatefulCat) @@ -199,8 +210,12 @@ function next_item(config::StatefulCatConfig) end function ranked_items(config::StatefulCatConfig) - return sortperm(compute_criteria( - config.rules.next_item, config.tracked_responses[])) + criteria = compute_criteria( + config.rules.next_item, config.tracked_responses[]) + if criteria === nothing + return nothing + end + return sortperm(criteria) end function item_criteria(config::StatefulCatConfig) @@ -240,6 +255,10 @@ function get_ability(config::StatefulCatConfig) return (config.rules.ability_estimator(config.tracked_responses[]), nothing) end +function likelihood(config::StatefulCatConfig, ability) + pdf(distribution_estimator(config.rules.ability_estimator), config.tracked_responses[], ability) +end + function item_bank_size(config::StatefulCatConfig) return length(config.tracked_responses[].item_bank) end From 19c0db045430189f1e34cefcf45fcbdb0e598ee2 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 30 May 2025 01:01:45 +0300 Subject: [PATCH 09/42] Add SafeLikelihoodEstimator --- src/aggregators/ability_estimator.jl | 51 +++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/aggregators/ability_estimator.jl b/src/aggregators/ability_estimator.jl index 87a372e..da468cc 100644 --- a/src/aggregators/ability_estimator.jl +++ b/src/aggregators/ability_estimator.jl @@ -11,6 +11,8 @@ function Integrators.normdenom(rett::IntReturnType, rett(integrator(IntegralCoeffs.one, 0, est, tracked_responses)) end +# This is not type piracy, but maybe a slightly distasteful overload +# TODO: Fix this interface? function pdf(ability_est::DistributionAbilityEstimator, tracked_responses::TrackedResponses, x) @@ -42,6 +44,53 @@ function pdf(est::PriorAbilityEstimator, AbilityLikelihood(tracked_responses)) end +function multiple_response_types_guard(tracked_responses) + if length(tracked_responses.responses.values) == 0 + return false + end + seen_value = tracked_responses.responses.values[1] + for value in tracked_responses.responses.values + if value !== seen_value + return true + end + end + return false +end + +struct GuardedAbilityEstimator{T <: DistributionAbilityEstimator, U <: DistributionAbilityEstimator, F} <: DistributionAbilityEstimator + est::T + fallback::U + guard::F +end + +function pdf(est::GuardedAbilityEstimator, + tracked_responses::TrackedResponses) + if est.guard(tracked_responses) + return pdf(est.est, tracked_responses) + else + return pdf(est.fallback, tracked_responses) + end +end + +function SafeLikelihoodAbilityEstimator(args...; kwargs...) + GuardedAbilityEstimator( + LikelihoodAbilityEstimator(), + PriorAbilityEstimator(args...), + multiple_response_types_guard + ) +end + +unlog(x) = x +unlog(x::Logarithmic{T}) where {T} = T(x) +unlog(x::ULogarithmic{T}) where {T} = T(x) +unlog(x::AbstractVector{Logarithmic{T}}) where {T} = T.(x) +unlog(x::AbstractVector{ULogarithmic{T}}) where {T} = T.(x) +#=unlog(x::ErrorIntegrationResult{Logarithmic{T}}) where {T} = T(x) +unlog(x::ErrorIntegrationResult{ULogarithmic{T}}) where {T} = T(x) +unlog(x::ErrorIntegrationResult{AbstractVector{Logarithmic{T}}}) where {T} = T.(x) +unlog(x::ErrorIntegrationResult{AbstractVector{ULogarithmic{T}}}) where {T} = T.(x) +=# + function expectation(rett::IntReturnType, f::F, ncomp, @@ -49,7 +98,7 @@ function expectation(rett::IntReturnType, est::DistributionAbilityEstimator, tracked_responses::TrackedResponses, denom = normdenom(rett, integrator, est, tracked_responses)) where {F} - rett(integrator(f, ncomp, est, tracked_responses)) / denom + unlog(rett(integrator(f, ncomp, est, tracked_responses)) / denom) end function expectation(f::F, From 1a6b7d222fb80227cd01177c16999ea7ec8e634c Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 30 May 2025 01:09:18 +0300 Subject: [PATCH 10/42] Various changes * Add pointwise/category information criterion * Make a bunch of stuff generic for number * Other stuff too --- src/aggregators/Aggregators.jl | 12 +- src/aggregators/optimizers.jl | 3 +- src/aggregators/riemann.jl | 33 +++-- src/logitembank.jl | 6 +- src/next_item_rules/NextItemRules.jl | 22 ++-- .../combinators/expectation.jl | 14 ++- src/next_item_rules/combinators/likelihood.jl | 88 ++++++++++++- .../criteria/item/information.jl | 14 ++- src/next_item_rules/criteria/item/urry.jl | 5 + .../criteria/pointwise/information.jl | 119 ++++++++++++++++++ .../information_special.jl | 0 .../information_support.jl | 28 ++++- src/next_item_rules/criteria/pointwise/kl.jl | 5 +- .../criteria/state/ability_variance.jl | 19 +-- src/next_item_rules/porcelain/aliases.jl | 98 --------------- src/next_item_rules/prelude/abstract.jl | 21 +++- src/next_item_rules/prelude/criteria.jl | 50 ++++++-- src/next_item_rules/prelude/next_item_rule.jl | 26 ++-- src/next_item_rules/strategies/exhaustive.jl | 4 +- src/precompiles.jl | 8 +- 20 files changed, 385 insertions(+), 190 deletions(-) create mode 100644 src/next_item_rules/criteria/pointwise/information.jl rename src/next_item_rules/criteria/{item => pointwise}/information_special.jl (100%) rename src/next_item_rules/criteria/{item => pointwise}/information_support.jl (77%) diff --git a/src/aggregators/Aggregators.jl b/src/aggregators/Aggregators.jl index b73e188..e2578eb 100644 --- a/src/aggregators/Aggregators.jl +++ b/src/aggregators/Aggregators.jl @@ -10,6 +10,7 @@ using StaticArrays: SVector using Distributions: Distribution, Normal, Distributions using Base.Threads using ForwardDiff: ForwardDiff +using LogarithmicNumbers: Logarithmic, ULogarithmic using FittedItemBanks: AbstractItemBank, ContinuousDomain, DichotomousSmoothedItemBank, DiscreteIndexableDomain, @@ -24,12 +25,14 @@ using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome, find1_type_sloppy using PsychometricsBazaarBase.Integrators: Integrators, BareIntegrationResult, - FixedGridIntegrator, IntReturnType, + FixedGridIntegrator, + IntReturnType, IntValue, Integrator, PreallocatedFixedGridIntegrator, normdenom using PsychometricsBazaarBase.Optimizers: OneDimOptimOptimizer, Optimizer using PsychometricsBazaarBase.ConstDistributions: std_normal, std_mv_normal +import Distributions: pdf import FittedItemBanks import PsychometricsBazaarBase.IntegralCoeffs @@ -38,7 +41,8 @@ export AbilityEstimator, TrackedResponses export AbilityTracker, NullAbilityTracker, PointAbilityTracker, GriddedAbilityTracker export ClosedFormNormalAbilityTracker, track! export response_expectation, expectation, distribution_estimator -export PointAbilityEstimator, PriorAbilityEstimator, LikelihoodAbilityEstimator +export PointAbilityEstimator, PriorAbilityEstimator +export SafeLikelihoodAbilityEstimator, LikelihoodAbilityEstimator export ModeAbilityEstimator, MeanAbilityEstimator export Speculator, replace_speculation!, normdenom, maybe_tracked_ability_estimate export AbilityIntegrator, AbilityOptimizer @@ -70,6 +74,10 @@ end abstract type DistributionAbilityEstimator <: AbilityEstimator end function DistributionAbilityEstimator(bits...) @returnsome find1_instance(DistributionAbilityEstimator, bits) + point_ability_estimator = find1_instance(PointAbilityEstimator, bits) + if point_ability_estimator !== nothing + return distribution_estimator(point_ability_estimator) + end end abstract type PointAbilityEstimator <: AbilityEstimator end diff --git a/src/aggregators/optimizers.jl b/src/aggregators/optimizers.jl index 2c45006..6585c28 100644 --- a/src/aggregators/optimizers.jl +++ b/src/aggregators/optimizers.jl @@ -47,5 +47,6 @@ function (optim::AbilityOptimizer)(f::F, est, tracked_responses::TrackedResponses; kwargs...) where {F} - optim(maybe_apply_prior(f, est), AbilityLikelihood(tracked_responses); kwargs...) + #optim(maybe_apply_prior(f, est), AbilityLikelihood(tracked_responses); kwargs...) + optim(f, pdf(est, tracked_responses); kwargs...) end diff --git a/src/aggregators/riemann.jl b/src/aggregators/riemann.jl index 243e0ae..cf5ea82 100644 --- a/src/aggregators/riemann.jl +++ b/src/aggregators/riemann.jl @@ -26,13 +26,32 @@ function (integrator::RiemannEnumerationIntegrator)(f::F, return BareIntegrationResult(result) end -function (integrator::Union{RiemannEnumerationIntegrator, FunctionIntegrator})(f::F, - ncomp, - est, - tracked_responses::TrackedResponses; - kwargs...) where {F} - integrator(maybe_apply_prior(f, est), +function (integrator::RiemannEnumerationIntegrator)( + f::F, + ncomp, + est, + tracked_responses::TrackedResponses; + kwargs... +) where {F} + integrator( + maybe_apply_prior(f, est), ncomp, AbilityLikelihood(tracked_responses); - kwargs...) + kwargs... + ) +end + +function (integrator::FunctionIntegrator)( + f::F, + ncomp, + est, + tracked_responses::TrackedResponses; + kwargs... +) where {F} + integrator( + f, + ncomp, + pdf(est, tracked_responses); + kwargs... + ) end diff --git a/src/logitembank.jl b/src/logitembank.jl index 23312e5..bfb19c1 100644 --- a/src/logitembank.jl +++ b/src/logitembank.jl @@ -21,18 +21,18 @@ inner_ir(ir::ItemResponse{<:LogItemBank}) = ItemResponse(ir.item_bank.inner, ir. ## TODO: Support item banks with other response types e.g. Float32 function FittedItemBanks.resp(ir::ItemResponse{<:LogItemBank}, θ) - exp(ULogarithmic{Float64}, FittedItemBanks.log_resp(inner_ir(ir), θ)) + exp(ULogarithmic, FittedItemBanks.log_resp(inner_ir(ir), θ)) end function FittedItemBanks.resp(ir::ItemResponse{<:LogItemBank}, response, θ) exp( - ULogarithmic{Float64}, + ULogarithmic, FittedItemBanks.log_resp(inner_ir(ir), response, θ) ) end function FittedItemBanks.resp_vec(ir::ItemResponse{<:LogItemBank}, θ) - exp.(ULogarithmic{Float64}, FittedItemBanks.log_resp_vec(inner_ir(ir), θ)) + exp.(ULogarithmic, FittedItemBanks.log_resp_vec(inner_ir(ir), θ)) end @forward LogItemBank.inner Base.length, diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index 48565c9..ac97c89 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -11,7 +11,7 @@ Springer, New York, NY. module NextItemRules using DocStringExtensions: FUNCTIONNAME, TYPEDEF, TYPEDFIELDS -using PsychometricsBazaarBase.Parameters: @with_kw +using PsychometricsBazaarBase.Parameters using LinearAlgebra: det, tr using Random: AbstractRNG, Xoshiro @@ -19,14 +19,14 @@ using ..Responses: BareResponses using ..ConfigBase using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome, find1_instance, find1_type -using PsychometricsBazaarBase.Integrators: Integrator +using PsychometricsBazaarBase.Integrators: Integrator, intval using PsychometricsBazaarBase: Integrators import PsychometricsBazaarBase.IntegralCoeffs using FittedItemBanks: AbstractItemBank, DiscreteDomain, DomainType, ItemResponse, OneDimContinuousDomain, domdims, item_params, resp, resp_vec, responses using ..Aggregators -using ..Aggregators: covariance_matrix +using ..Aggregators: covariance_matrix, FunctionProduct using Distributions: logccdf, logcdf, pdf using Base.Threads @@ -38,13 +38,17 @@ import ForwardDiff export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread export NextItemRule, ItemStrategyNextItemRule export UrryItemCriterion, InformationItemCriterion +export LikelihoodWeightedItemCriterion, PointItemCriterion +export LikelihoodWeightedItemCategoryCriterion, PointItemCategoryCriterion +export ObservedInformationPointwiseItemCategoryCriterion +export RawEmpiricalInformationPointwiseItemCategoryCriterion +export EmpiricalInformationPointwiseItemCategoryCriterion +export TotalItemInformation export RandomNextItemRule export PiecewiseNextItemRule, MemoryNextItemRule, FixedFirstItemNextItemRule export ExhaustiveSearch -export catr_next_item_aliases export preallocate -export compute_criteria, compute_criterion, compute_multi_criterion, - compute_pointwise_criterion +export compute_criteria, compute_criterion, compute_multi_criterion export best_item export PointResponseExpectation, DistributionResponseExpectation export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer @@ -70,15 +74,15 @@ include("./combinators/scalarizers.jl") include("./combinators/likelihood.jl") # Criteria -include("./criteria/item/information_special.jl") -include("./criteria/item/information_support.jl") include("./criteria/item/information.jl") include("./criteria/item/urry.jl") include("./criteria/state/ability_variance.jl") +include("./criteria/pointwise/information_special.jl") +include("./criteria/pointwise/information_support.jl") +include("./criteria/pointwise/information.jl") include("./criteria/pointwise/kl.jl") # Porcelain include("./porcelain/porcelain.jl") -include("./porcelain/aliases.jl") end diff --git a/src/next_item_rules/combinators/expectation.jl b/src/next_item_rules/combinators/expectation.jl index 61ac76f..87c945b 100644 --- a/src/next_item_rules/combinators/expectation.jl +++ b/src/next_item_rules/combinators/expectation.jl @@ -67,7 +67,7 @@ item 1-ply ahead. """ struct ExpectationBasedItemCriterion{ ResponseExpectationT <: ResponseExpectation, - CriterionT <: Union{StateCriterion, ItemCriterion} + CriterionT <: Union{StateCriterion, ItemCriterion, ItemCategoryCriterion}, } <: ItemCriterion response_expectation::ResponseExpectationT criterion::CriterionT @@ -75,7 +75,8 @@ end function _get_some_criterion(bits...; kwargs...) @returnsome StateCriterion(bits...; kwargs...) - @returnsome ItemCriterion(bits...; kwargs...) + @returnsome ItemCriterion(bits...; skip_expectation=true, kwargs...) + @returnsome ItemCategoryCriterion(bits...) end function ExpectationBasedItemCriterion(bits...; @@ -95,13 +96,16 @@ function init_thread(::ExpectationBasedItemCriterion, responses::TrackedResponse Speculator(responses, 1) end -function _generic_criterion(criterion::StateCriterion, tracked_responses, item_idx) +function _generic_criterion(criterion::StateCriterion, tracked_responses, _item_idx, _response) compute_criterion(criterion, tracked_responses) end # TODO: Support init_thread for wrapped ItemCriterion -function _generic_criterion(criterion::ItemCriterion, tracked_responses, item_idx) +function _generic_criterion(criterion::ItemCriterion, tracked_responses, item_idx, _response) compute_criterion(criterion, tracked_responses, item_idx) end +function _generic_criterion(criterion::ItemCategoryCriterion, tracked_responses, item_idx, response) + compute_criterion(criterion, tracked_responses, item_idx, response) +end function compute_criterion( item_criterion::ExpectationBasedItemCriterion, @@ -116,7 +120,7 @@ function compute_criterion( for (prob, possible_response) in zip(exp_resp, possible_responses) replace_speculation!(speculator, SVector(item_idx), SVector(possible_response)) res += prob * - _generic_criterion(item_criterion.criterion, speculator.responses, item_idx) + _generic_criterion(item_criterion.criterion, speculator.responses, item_idx, possible_response) end res end diff --git a/src/next_item_rules/combinators/likelihood.jl b/src/next_item_rules/combinators/likelihood.jl index 03da6a6..6da661c 100644 --- a/src/next_item_rules/combinators/likelihood.jl +++ b/src/next_item_rules/combinators/likelihood.jl @@ -8,12 +8,90 @@ struct LikelihoodWeightedItemCriterion{ estimator::AbilityEstimatorT end +function LikelihoodWeightedItemCriterion(bits...) + @requiresome dist_est_integrator_pair = get_dist_est_and_integrator(bits...) + (dist_est, integrator) = dist_est_integrator_pair + criterion = PointwiseItemCriterion(bits...) + return LikelihoodWeightedItemCriterion(criterion, integrator, dist_est) +end + function compute_criterion( - lwic::LikelihoodWeightedItemCriterion, - tracked_responses::TrackedResponses, - item_idx + lwic::LikelihoodWeightedItemCriterion, + tracked_responses::TrackedResponses, + item_idx ) func = FunctionProduct( - pdf(lwic.estimator, tracked_responses), lwic.criterion(tracked_responses, item_idx)) - lwic.integrator(func, 0, lwic.estimator, tracked_responses) + pdf(lwic.estimator, tracked_responses), ability -> compute_criterion(lwic.criterion, tracked_responses, item_idx, ability)) + intval(lwic.integrator(func, 0, lwic.estimator, tracked_responses)) +end + +struct PointItemCriterion{ + PointwiseItemCriterionT <: PointwiseItemCriterion, + AbilityEstimatorT <: PointAbilityEstimator +} <: ItemCriterion + criterion::PointwiseItemCriterionT + estimator::AbilityEstimatorT end + +function compute_criterion( + pic::PointItemCriterion, + tracked_responses::TrackedResponses, + item_idx +) + ability = maybe_tracked_ability_estimate( + tracked_responses, + pic.estimator + ) + return compute_criterion(pic.criterion, tracked_responses, item_idx, ability) +end + +struct LikelihoodWeightedItemCategoryCriterion{ + PointwiseItemCategoryCriterionT <: PointwiseItemCategoryCriterion, + AbilityIntegratorT <: AbilityIntegrator, + AbilityEstimatorT <: DistributionAbilityEstimator +} <: ItemCategoryCriterion + criterion::PointwiseItemCategoryCriterionT + integrator::AbilityIntegratorT + estimator::AbilityEstimatorT +end + +function LikelihoodWeightedItemCategoryCriterion(bits...) + @requiresome dist_est_integrator_pair = get_dist_est_and_integrator(bits...) + (dist_est, integrator) = dist_est_integrator_pair + criterion = PointwiseItemCategoryCriterion(bits...) + return LikelihoodWeightedItemCategoryCriterion(criterion, integrator, dist_est) +end + +function compute_criterion( + lwicc::LikelihoodWeightedItemCategoryCriterion, + tracked_responses::TrackedResponses, + item_idx, + category +) + func = FunctionProduct( + pdf(lwicc.estimator, tracked_responses), + ability -> compute_criterion(lwicc.criterion, tracked_responses, item_idx, ability, category) + ) + intval(lwicc.integrator(func, 0, lwicc.estimator, tracked_responses)) +end + +struct PointItemCategoryCriterion{ + PointwiseItemCategoryCriterionT <: PointwiseItemCategoryCriterion, + AbilityEstimatorT <: PointAbilityEstimator +} <: ItemCategoryCriterion + criterion::PointwiseItemCategoryCriterionT + estimator::AbilityEstimatorT +end + +function compute_criterion( + pic::PointItemCategoryCriterion, + tracked_responses::TrackedResponses, + item_idx, + category +) + ability = maybe_tracked_ability_estimate( + tracked_responses, + pic.estimator + ) + return compute_criterion(pic.criterion, tracked_responses, item_idx, ability, category) +end \ No newline at end of file diff --git a/src/next_item_rules/criteria/item/information.jl b/src/next_item_rules/criteria/item/information.jl index 04987e4..88f7850 100644 --- a/src/next_item_rules/criteria/item/information.jl +++ b/src/next_item_rules/criteria/item/information.jl @@ -1,12 +1,20 @@ # TODO: Should have Variants for point ability versus distribution ability -struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <: +@kw_only struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <: ItemCriterion ability_estimator::AbilityEstimatorT expected_item_information::F end -function InformationItemCriterion(ability_estimator) - InformationItemCriterion(ability_estimator, expected_item_information) +function InformationItemCriterion(ability_estimator::PointAbilityEstimator) + InformationItemCriterion(; + ability_estimator, + expected_item_information + ) +end + +function InformationItemCriterion(bits...) + @requiresome ability_estimator = PointAbilityEstimator(bits...) + InformationItemCriterion(ability_estimator) end function compute_criterion( diff --git a/src/next_item_rules/criteria/item/urry.jl b/src/next_item_rules/criteria/item/urry.jl index e71a82b..177c36f 100644 --- a/src/next_item_rules/criteria/item/urry.jl +++ b/src/next_item_rules/criteria/item/urry.jl @@ -9,6 +9,11 @@ struct UrryItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <: ItemCrit ability_estimator::AbilityEstimatorT end +function UrryItemCriterion(bits...) + @requiresome ability_estimator = PointAbilityEstimator(bits...) + UrryItemCriterion(ability_estimator) +end + # TODO: Slow + poor error handling function raw_difficulty(item_bank, item_idx) item_params(item_bank, item_idx).difficulty diff --git a/src/next_item_rules/criteria/pointwise/information.jl b/src/next_item_rules/criteria/pointwise/information.jl new file mode 100644 index 0000000..a9845e6 --- /dev/null +++ b/src/next_item_rules/criteria/pointwise/information.jl @@ -0,0 +1,119 @@ +""" +This calculates the pointwise information criterion for an item response model. +""" +struct ObservedInformationPointwiseItemCategoryCriterion <: PointwiseItemCategoryCriterion end + +function compute_criterion( + ::ObservedInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability, + category +) + actual = -double_derivative((ability -> log_resp(ir, category, ability)), ability) .* resp(ir, category, ability) + -actual +end + +function compute_criterion_vec( + ::ObservedInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability +) + actual = -double_derivative((ability -> log_resp_vec(ir, ability)), ability) .* resp_vec(ir, ability) + -actual +end + +""" +See EmpiricalInformationPointwiseItemCategoryCriterion for more details. +""" +struct RawEmpiricalInformationPointwiseItemCategoryCriterion <: PointwiseItemCategoryCriterion end + +function compute_criterion( + ::RawEmpiricalInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability, + category +) + actual = ForwardDiff.derivative(ability -> resp(ir, category, ability), ability) ^ 2 / resp(ir, category, ability) + -actual +end + +function compute_criterion_vec( + ::RawEmpiricalInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability +) + actual = ForwardDiff.derivative(ability -> resp_vec(ir, ability), ability) .^ 2 ./ resp_vec(ir, ability) + -actual +end + +""" +In equation 10 of [1] we see that we can compute information using 2nd derivatives of log likelihood or 1st derivative squared. +For single categories, we need to an extra term which disappears when we calculate the total see [2]. +For this reason +`RawEmpiricalInformationPointwiseItemCategoryCriterion` +computes without this factor, while +`EmpiricalInformationPointwiseItemCategoryCriterion` +computes with it. + +So in general, only use the former with `TotalItemInformation` + +[1] +``Information Functions of the Generalized Partial Credit Model'' +Eiji Muraki +https://doi.org/10.1177/014662169301700403 + +[2] +https://mark.reid.name/blog/fisher-information-and-log-likelihood.html +""" +struct EmpiricalInformationPointwiseItemCategoryCriterion <: PointwiseItemCategoryCriterion end + +function compute_criterion( + ::EmpiricalInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability, + category +) + actual = -compute_criterion( + RawEmpiricalInformationPointwiseItemCategoryCriterion(), + ir, + ability, + category + ) .- double_derivative((ability -> resp(ir, category, ability)), ability) + -actual +end + +function compute_criterion_vec( + ::EmpiricalInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability +) + actual = -compute_criterion_vec( + RawEmpiricalInformationPointwiseItemCategoryCriterion(), + ir, + ability + ) .- double_derivative((ability -> resp_vec(ir, ability)), ability) + -actual +end + +#= +""" +This implements Fisher information as a pointwise item criterion. +It uses ForwardDiff to find the second derivative of the log prob for the current item and ability estimate. +It then uses the expected outcome at the given ability estimate to weight the outcomes. + +\[ +E_{\thetaHAT}(log(\frac{d^2 log\thetaHAT}{d\theta)) +\] +""" +=# +struct TotalItemInformation{PointwiseItemCategoryCriterionT <: PointwiseItemCategoryCriterion} <: PointwiseItemCriterion + pcic::PointwiseItemCategoryCriterionT +end + +function compute_criterion( + tii::TotalItemInformation, + ir::ItemResponse, + ability +) + sum(compute_criterion_vec(tii.pcic, ir, ability)) +end \ No newline at end of file diff --git a/src/next_item_rules/criteria/item/information_special.jl b/src/next_item_rules/criteria/pointwise/information_special.jl similarity index 100% rename from src/next_item_rules/criteria/item/information_special.jl rename to src/next_item_rules/criteria/pointwise/information_special.jl diff --git a/src/next_item_rules/criteria/item/information_support.jl b/src/next_item_rules/criteria/pointwise/information_support.jl similarity index 77% rename from src/next_item_rules/criteria/item/information_support.jl rename to src/next_item_rules/criteria/pointwise/information_support.jl index c63e6af..4e798be 100644 --- a/src/next_item_rules/criteria/item/information_support.jl +++ b/src/next_item_rules/criteria/pointwise/information_support.jl @@ -1,6 +1,6 @@ using FittedItemBanks: CdfMirtItemBank, - GuessItemBank, SlipItemBank, TransferItemBank, AnySlipOrGuessItemBank -using FittedItemBanks: inner_item_response, norm_abil, y_offset, irf_size + TransferItemBank, GuessAndSlipItemBank +using FittedItemBanks: inner_item_response, norm_abil, irf_size using StatsFuns: logaddexp function log_resp_vec(ir::ItemResponse{<:TransferItemBank}, θ) @@ -30,9 +30,10 @@ function log_resp(ir::ItemResponse{<:CdfMirtItemBank}, val, θ) end end +#= # XXX: Not sure if this is optimal numerically or speed wise -- possibly it # would be better to just transform to linear space in this case? -@inline function log_transform_irf_y(guess::Float64, slip::Float64, y) +@inline function log_transform_irf_y(guess, slip, y) # log space version of guess + irf_size(guess, slip) * y logaddexp(log(guess), log(irf_size(guess, slip)) + y) end @@ -63,6 +64,11 @@ end function log_resp(ir::ItemResponse{<:AnySlipOrGuessItemBank}, val, θ) log_transform_irf_y(ir, val, log_resp(inner_item_response(ir), val, θ)) end +=# + +log_resp(ir::ItemResponse{<:GuessAndSlipItemBank}, response, θ) = log(resp(ir, response, θ)) +log_resp(ir::ItemResponse{<:GuessAndSlipItemBank}, θ) = log(resp(ir, θ)) +log_resp_vec(ir::ItemResponse{<:GuessAndSlipItemBank}, θ) = log.(resp_vec(ir, θ)) function vector_hessian(f, x, n) out = ForwardDiff.jacobian(x -> ForwardDiff.jacobian(f, x), x) @@ -73,7 +79,7 @@ function double_derivative(f, x) ForwardDiff.derivative(x -> ForwardDiff.derivative(f, x), x) end -function expected_item_information(ir::ItemResponse, θ::Float64) +function expected_item_information(ir::ItemResponse, θ::Number) exp_resp = resp_vec(ir, θ) d² = double_derivative((θ -> log_resp_vec(ir, θ)), θ) -sum(exp_resp .* d²) @@ -81,7 +87,7 @@ end # TODO: Unclear whether this should be implemented with ExpectationBasedItemCriterion # TODO: This is not implementing DRule but postposterior DRule -function expected_item_information(ir::ItemResponse, θ::Vector{Float64}) +function expected_item_information(ir::ItemResponse, θ::Vector) exp_resp = resp_vec(ir, θ) n = domdims(ir.item_bank) hess = vector_hessian(θ -> log_resp_vec(ir, θ), θ, n) @@ -99,3 +105,15 @@ function responses_information(item_bank::AbstractItemBank, responses::BareRespo for (resp_idx, resp_value) in zip(responses.indices, responses.values)); init = zeros(d, d)) end + +using ComputerAdaptiveTesting: ItemBanks + +function log_resp_vec(ir::ItemResponse{<:ItemBanks.LogItemBank}, θ) + # XXX: Should not destruct the logarithmic number here + # Works for now + log.(resp_vec(ItemBanks.inner_ir(ir), θ)) +end + +function log_resp(ir::ItemResponse{<:ItemBanks.LogItemBank}, resp, θ) + log(resp(ItemBanks.inner_ir(ir), resp, θ)) +end \ No newline at end of file diff --git a/src/next_item_rules/criteria/pointwise/kl.jl b/src/next_item_rules/criteria/pointwise/kl.jl index 630680c..efaf115 100644 --- a/src/next_item_rules/criteria/pointwise/kl.jl +++ b/src/next_item_rules/criteria/pointwise/kl.jl @@ -22,10 +22,11 @@ function PosteriorExpectedKLInformationItemCriterion(bits...) point_estimator, distribution_estimator, integrator) end -function compute_pointwise_criterion( +function compute_criterion( item_criterion::PosteriorExpectedKLInformationItemCriterion, tracked_responses::TrackedResponses, - item_idx) + item_idx, + theta) theta_0 = maybe_tracked_ability_estimate(tracked_responses, item_criterion.point_estimator) item_response = ItemResponse(tracked_responses.item_bank, item_idx) diff --git a/src/next_item_rules/criteria/state/ability_variance.jl b/src/next_item_rules/criteria/state/ability_variance.jl index 343a873..6e44b49 100644 --- a/src/next_item_rules/criteria/state/ability_variance.jl +++ b/src/next_item_rules/criteria/state/ability_variance.jl @@ -14,25 +14,10 @@ struct AbilityVarianceStateCriterion{ skip_zero::Bool end -function _get_dist_est_and_integrator(bits...) - # XXX: Weakness in this initialisation system is showing now - # This needs ot be explicitly passed dist_est and integrator, but this may - # be burried within a MeanAbilityEstimator - dist_est = DistributionAbilityEstimator(bits...) - integrator = AbilityIntegrator(bits...) - if dist_est !== nothing && integrator !== nothing - return (dist_est, integrator) - end - # So let's just handle this case individually for now - # (Is this going to cause a problem with this being picked over something more appropriate?) - @requiresome mean_ability_est = MeanAbilityEstimator(bits...) - return (mean_ability_est.dist_est, mean_ability_est.integrator) -end - function AbilityVarianceStateCriterion(bits...) skip_zero = false @returnsome find1_instance(AbilityVarianceStateCriterion, bits) - @requiresome dist_est_integrator_pair = _get_dist_est_and_integrator(bits...) + @requiresome dist_est_integrator_pair = get_dist_est_and_integrator(bits...) (dist_est, integrator) = dist_est_integrator_pair return AbilityVarianceStateCriterion(dist_est, integrator, skip_zero) end @@ -94,7 +79,7 @@ end function AbilityCovarianceStateMultiCriterion(bits...) skip_zero = false - @requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...) + @requiresome (dist_est, integrator) = get_dist_est_and_integrator(bits...) return AbilityCovarianceStateMultiCriterion(dist_est, integrator, skip_zero) end diff --git a/src/next_item_rules/porcelain/aliases.jl b/src/next_item_rules/porcelain/aliases.jl index 392b9a5..e69de29 100644 --- a/src/next_item_rules/porcelain/aliases.jl +++ b/src/next_item_rules/porcelain/aliases.jl @@ -1,98 +0,0 @@ -""" -This mapping provides next item rules through the same names that they are -available through in the `catR` R package. TODO compability with `mirtcat` -""" -const catr_next_item_aliases = Dict( - "MFI" => (ability_estimator; parallel = true) -> ItemStrategyNextItemRule( - ExhaustiveSearch(parallel), - InformationItemCriterion(ability_estimator)), - "bOpt" => (ability_estimator; parallel = true) -> ItemStrategyNextItemRule( - ExhaustiveSearch(parallel), - UrryItemCriterion(ability_estimator)), - "MEPV" => (ability_estimator; parallel = true) -> ItemStrategyNextItemRule( - ExhaustiveSearch(parallel), - ExpectationBasedItemCriterion(ability_estimator, - AbilityVarianceStateCriterion(ability_estimator))) #"MLWI", #"MPWI", #"MEI", -) - -#"thOpt", -#"progressive", -#"proportional", -#"KL", -#"KLP", -#"GDI", -#"GDIP", -#"random" - -function _mirtcat_helper(item_criterion_callback) - function _helper(bits...; ability_estimator = nothing) - ability_estimator = AbilityEstimator(bits...; ability_estimator = ability_estimator) - item_criterion = item_criterion_callback( - [bits..., ability_estimator], ability_estimator) - return ItemStrategyNextItemRule(ExhaustiveSearch(), item_criterion) - end - return _helper -end - -const mirtcat_next_item_aliases = Dict( - # "MI' for the maximum information - "MI" => _mirtcat_helper((bits, ability_estimator) -> InformationItemCriterion(ability_estimator)), - # 'MEPV' for minimum expected posterior variance - "MEPV" => _mirtcat_helper((bits, ability_estimator) -> ExpectationBasedItemCriterion( - ability_estimator, - AbilityVarianceStateCriterion(bits...))), - "Drule" => _mirtcat_helper((bits, ability_estimator) -> DRuleItemCriteron(ability_estimator)), - "Trule" => _mirtcat_helper((bits, ability_estimator) -> TRuleItemCriteron(ability_estimator)) -) - -# 'MLWI' for maximum likelihood weighted information -#"MLWI" => _mirtcat_helper((bits, ability_estimator) -> InformationItemCriterion(ability_estimator)) -# 'MPWI' for maximum posterior weighted information -# 'MEI' for maximum expected information -# 'IKLP' as well as 'IKL' for the integration based Kullback-Leibler criteria with and without the prior density weight, -# respectively, and their root-n items administered weighted counter-parts, 'IKLn' and 'IKLPn'. -#= -Possible inputs for multidimensional adaptive tests include: 'Drule' for the -maximum determinant of the information matrix, 'Trule' for the maximum -(potentially weighted) trace of the information matrix, 'Arule' for the minimum (potentially weighted) trace of the asymptotic covariance matrix, 'Erule' -for the minimum value of the information matrix, and 'Wrule' for the weighted -information criteria. For each of these rules, the posterior weight for the latent trait scores can also be included with the 'DPrule', 'TPrule', 'APrule', -'EPrule', and 'WPrule', respectively. -Applicable to both unidimensional and multidimensional tests are the 'KL' and -'KLn' for point-wise Kullback-Leibler divergence and point-wise KullbackLeibler with a decreasing delta value (delta*sqrt(n), where n is the number -of items previous answered), respectively. The delta criteria is defined in the -design object -Non-adaptive methods applicable even when no mo object is passed are: 'random' -to randomly select items, and 'seq' for selecting items sequentially -=# - -const mirtcat_ability_estimator_aliases = Dict( -# "MAP" for the maximum a-posteriori (i.e, Bayes modal) -# "ML" for maximum likelihood -# "WLE" for weighted likelihood estimation -# "EAPsum" for the expected a-posteriori for each sum score -# "EAP" for the expected a-posteriori (default). -) - -#= -• "plausible" for a single plausible value imputation for each case. This is -equivalent to setting plausible.draws = 1 -• "classify" for the posteriori classification probabilities (only applicable -when the input model was of class MixtureClass) -=# - -function mirtcat_quadpts(nfact) - if nfact == 1 - return 121 - elseif nfact == 2 - return 61 - elseif nfact == 3 - return 31 - elseif nfact == 4 - return 19 - elseif nfact == 5 - return 11 - else - return 5 - end -end diff --git a/src/next_item_rules/prelude/abstract.jl b/src/next_item_rules/prelude/abstract.jl index fdc68c9..5f3c665 100644 --- a/src/next_item_rules/prelude/abstract.jl +++ b/src/next_item_rules/prelude/abstract.jl @@ -29,21 +29,32 @@ abstract type NextItemStrategy <: CatConfigBase end """ $(TYPEDEF) -Abstract type for next item criteria +Abstract base type all criteria should inherit from """ -abstract type ItemCriterion <: CatConfigBase end +abstract type CriterionBase <: CatConfigBase end +abstract type ItemCriterionBase <: CatConfigBase end + +abstract type ItemCriterion <: ItemCriterionBase end + +""" +$(TYPEDEF) +""" +abstract type StateCriterion <: CriterionBase end """ $(TYPEDEF) """ -abstract type StateCriterion <: CatConfigBase end +abstract type PointwiseItemCriterion <: ItemCriterionBase end """ $(TYPEDEF) """ -abstract type PointwiseItemCriterion <: CatConfigBase end +abstract type ItemCategoryCriterion <: ItemCriterionBase end -abstract type PurePointwiseItemCriterion <: PointwiseItemCriterion end +""" +$(TYPEDEF) +""" +abstract type PointwiseItemCategoryCriterion <: ItemCriterionBase end abstract type MatrixScalarizer end abstract type StateMultiCriterion end diff --git a/src/next_item_rules/prelude/criteria.jl b/src/next_item_rules/prelude/criteria.jl index 277c65d..1e2055e 100644 --- a/src/next_item_rules/prelude/criteria.jl +++ b/src/next_item_rules/prelude/criteria.jl @@ -1,13 +1,15 @@ #= Single dimensional =# -function ItemCriterion(bits...; ability_estimator = nothing, ability_tracker = nothing) +function ItemCriterion(bits...; ability_estimator = nothing, ability_tracker = nothing, skip_expectation = false) @returnsome find1_instance(ItemCriterion, bits) @returnsome find1_type(ItemCriterion, bits) typ->typ( ability_estimator = ability_estimator, ability_tracker = ability_tracker) - @returnsome ExpectationBasedItemCriterion(bits...; - ability_estimator = ability_estimator, - ability_tracker = ability_tracker) + if !skip_expectation + @returnsome ExpectationBasedItemCriterion(bits...; + ability_estimator = ability_estimator, + ability_tracker = ability_tracker) + end end function StateCriterion(bits...; ability_estimator = nothing, ability_tracker = nothing) @@ -15,6 +17,21 @@ function StateCriterion(bits...; ability_estimator = nothing, ability_tracker = @returnsome find1_type(StateCriterion, bits) typ->typ() end +function ItemCategoryCriterion(bits...) + @returnsome find1_instance(ItemCategoryCriterion, bits) + @returnsome find1_type(ItemCategoryCriterion, bits) typ->typ() +end + +function PointwiseItemCriterion(bits...) + @returnsome find1_instance(PointwiseItemCriterion, bits) + @returnsome find1_type(PointwiseItemCriterion, bits) typ->typ() +end + +function PointwiseItemCategoryCriterion(bits...) + @returnsome find1_instance(PointwiseItemCategoryCriterion, bits) + @returnsome find1_type(PointwiseItemCategoryCriterion, bits) typ->typ() +end + function init_thread(::ItemCriterion, ::TrackedResponses) nothing end @@ -72,13 +89,9 @@ function compute_criteria( compute_criteria(rule.criterion, responses) end -function compute_pointwise_criterion( - ppic::PurePointwiseItemCriterion, tracked_responses, item_idx) - compute_pointwise_criterion(ppic, ItemResponse(tracked_responses.item_bank, item_idx)) -end - -struct PurePointwiseItemCriterionFunction{PointwiseItemCriterionT <: PointwiseItemCriterion} - item_response::ItemResponse +function compute_criterion( + ppic::ItemCriterionBase, tracked_responses::TrackedResponses, item_idx, args...) + compute_criterion(ppic, ItemResponse(tracked_responses.item_bank, item_idx), args...) end function init_thread(::ItemMultiCriterion, ::TrackedResponses) @@ -98,3 +111,18 @@ function compute_multi_criterion( state_criterion::StateMultiCriterion, ::Nothing, tracked_responses) compute_multi_criterion(state_criterion, tracked_responses) end + +function get_dist_est_and_integrator(bits...) + # XXX: Weakness in this initialisation system is showing now + # This needs ot be explicitly passed dist_est and integrator, but this may + # be burried within a MeanAbilityEstimator + dist_est = DistributionAbilityEstimator(bits...) + integrator = AbilityIntegrator(bits...) + if dist_est !== nothing && integrator !== nothing + return (dist_est, integrator) + end + # So let's just handle this case individually for now + # (Is this going to cause a problem with this being picked over something more appropriate?) + @requiresome mean_ability_est = MeanAbilityEstimator(bits...) + return (mean_ability_est.dist_est, mean_ability_est.integrator) +end diff --git a/src/next_item_rules/prelude/next_item_rule.jl b/src/next_item_rules/prelude/next_item_rule.jl index bd708e8..a5bc007 100644 --- a/src/next_item_rules/prelude/next_item_rule.jl +++ b/src/next_item_rules/prelude/next_item_rule.jl @@ -1,22 +1,20 @@ function NextItemRule(bits...; ability_estimator = nothing, - ability_tracker = nothing, - parallel = true) + ability_tracker = nothing) @returnsome find1_instance(NextItemRule, bits) @returnsome ItemStrategyNextItemRule(bits..., ability_estimator = ability_estimator, - ability_tracker = ability_tracker, - parallel = parallel) + ability_tracker = ability_tracker) end -function NextItemStrategy(; parallel = true) - ExhaustiveSearch(parallel) +function NextItemStrategy() + ExhaustiveSearch() end -function NextItemStrategy(bits...; parallel = true) +function NextItemStrategy(bits...) @returnsome find1_instance(NextItemStrategy, bits) - @returnsome find1_type(NextItemStrategy, bits) typ->typ(; parallel = parallel) - @returnsome NextItemStrategy(; parallel = parallel) + @returnsome find1_type(NextItemStrategy, bits) typ->typ() + @returnsome NextItemStrategy() end """ @@ -26,7 +24,7 @@ $(TYPEDFIELDS) `ItemStrategyNextItemRule` which together with a `NextItemStrategy` acts as an adapter by which an `ItemCriterion` can serve as a `NextItemRule`. - $(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true) + $(FUNCTIONNAME)(bits...; ability_estimator=nothing Implicit constructor for $(FUNCTIONNAME). Will default to `ExhaustiveSearch` when no `NextItemStrategy` is given. @@ -40,10 +38,9 @@ struct ItemStrategyNextItemRule{ end function ItemStrategyNextItemRule(bits...; - parallel = true, ability_estimator = nothing, ability_tracker = nothing) - strategy = NextItemStrategy(bits...; parallel = parallel) + strategy = NextItemStrategy(bits...) criterion = ItemCriterion(bits...; ability_estimator = ability_estimator, ability_tracker = ability_tracker) @@ -54,4 +51,9 @@ end function best_item(rule::NextItemRule, tracked_responses::TrackedResponses) best_item(rule, tracked_responses, tracked_responses.item_bank) +end + +# Default implementation +function compute_criteria(::NextItemRule, ::TrackedResponses) + nothing end \ No newline at end of file diff --git a/src/next_item_rules/strategies/exhaustive.jl b/src/next_item_rules/strategies/exhaustive.jl index 7b47429..4c13255 100644 --- a/src/next_item_rules/strategies/exhaustive.jl +++ b/src/next_item_rules/strategies/exhaustive.jl @@ -30,9 +30,7 @@ $(TYPEDEF) $(TYPEDFIELDS) """ -@with_kw struct ExhaustiveSearch <: NextItemStrategy - parallel::Bool = false -end +struct ExhaustiveSearch <: NextItemStrategy end function best_item( rule::ItemStrategyNextItemRule{ExhaustiveSearch, ItemCriterionT}, diff --git a/src/precompiles.jl b/src/precompiles.jl index c7b44f5..3225c8f 100644 --- a/src/precompiles.jl +++ b/src/precompiles.jl @@ -7,7 +7,8 @@ using PrecompileTools: @compile_workload, @setup_workload using Random: default_rng using .Aggregators: LikelihoodAbilityEstimator, MeanAbilityEstimator, GriddedAbilityTracker, AbilityIntegrator - using .NextItemRules: catr_next_item_aliases, preallocate + using .NextItemRules: preallocate, ExhaustiveSearch, ItemStrategyNextItemRule, + ExpectationBasedItemCriterion, AbilityVarianceStateCriterion using .Stateful: Stateful rng = default_rng(42) @@ -19,7 +20,10 @@ using PrecompileTools: @compile_workload, @setup_workload lh_grid_tracker = GriddedAbilityTracker(lh_ability_est, integrator) ability_integrator = AbilityIntegrator(integrator, lh_grid_tracker) ability_estimator = MeanAbilityEstimator(lh_ability_est, ability_integrator) - next_item_rule = catr_next_item_aliases["MEPV"](ability_estimator) + next_item_rule = ItemStrategyNextItemRule( + ExhaustiveSearch(), + ExpectationBasedItemCriterion(ability_estimator, + AbilityVarianceStateCriterion(ability_estimator))) cat = Stateful.StatefulCatConfig(CatConfig.CatRules(; next_item=next_item_rule, termination_condition=TerminationConditions.RunForeverTerminationCondition(), From cc964ef2e32a94afad4b78e8528963d93712dfbc Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 1 Jun 2025 14:10:54 +0300 Subject: [PATCH 11/42] Fix up comparison module directory --- src/{comparison => Comparison}/Comparison.jl | 0 src/{comparison => Comparison}/watchdog.jl | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename src/{comparison => Comparison}/Comparison.jl (100%) rename src/{comparison => Comparison}/watchdog.jl (100%) diff --git a/src/comparison/Comparison.jl b/src/Comparison/Comparison.jl similarity index 100% rename from src/comparison/Comparison.jl rename to src/Comparison/Comparison.jl diff --git a/src/comparison/watchdog.jl b/src/Comparison/watchdog.jl similarity index 100% rename from src/comparison/watchdog.jl rename to src/Comparison/watchdog.jl From 9055986fbad4795dc5b29618da76d56e4ffe1920 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 14:55:17 +0300 Subject: [PATCH 12/42] Remove dependency upon ResumableFunctions --- Project.toml | 4 +--- test/compat.jl | 1 - test/dummy.jl | 55 ++++++++++++++++++++---------------------------- test/runtests.jl | 1 - test/stateful.jl | 1 - 5 files changed, 24 insertions(+), 38 deletions(-) diff --git a/Project.toml b/Project.toml index 12276e5..01749db 100644 --- a/Project.toml +++ b/Project.toml @@ -62,7 +62,6 @@ PrecompileTools = "1.2.1" PsychometricsBazaarBase = "^0.8.1" Random = "^1.11" Reexport = "1" -ResumableFunctions = "^0.6, 1" Setfield = "^1" SparseArrays = "^1.11" StaticArrays = "1" @@ -75,8 +74,7 @@ julia = "^1.11" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Optim = "429524aa-4258-5aef-a3af-852621145aeb" -ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "JET", "Optim", "ResumableFunctions", "Test"] +test = ["Aqua", "JET", "Optim", "Test"] diff --git a/test/compat.jl b/test/compat.jl index 173d828..7acc663 100644 --- a/test/compat.jl +++ b/test/compat.jl @@ -11,7 +11,6 @@ using ComputerAdaptiveTesting.ItemBanks: LogItemBank using ComputerAdaptiveTesting.NextItemRules: best_item using ComputerAdaptiveTesting: Compat - using ResumableFunctions using Test: @test, @testset #include("./dummy.jl") diff --git a/test/dummy.jl b/test/dummy.jl index 5f71ddc..95b3895 100644 --- a/test/dummy.jl +++ b/test/dummy.jl @@ -10,7 +10,6 @@ using PsychometricsBazaarBase.Integrators using PsychometricsBazaarBase.Optimizers using Optim using Random -using ResumableFunctions struct DummyAbilityEstimator <: AbilityEstimator val::Any @@ -24,7 +23,7 @@ const optimizers_1d = [ FunctionOptimizer(OneDimOptimOptimizer(-6.0, 6.0, NelderMead())) ] const integrators_1d = [ - FunctionIntegrator(QuadGKIntegrator(-6, 6, 5)), + FunctionIntegrator(QuadGKIntegrator(lo=-6.0, hi=6.0, order=5)), FunctionIntegrator(FixedGKIntegrator(-6, 6, 80)) ] const ability_estimators_1d = [ @@ -46,47 +45,39 @@ const criteria_1d = [ ((), (stuff) -> RandomNextItemRule()) ] -@resumable function _get_stuffs(needed) +function _get_stuffs(needed) if :est in needed - for (extra_needed, mk_est) in ability_estimators_1d + return ( + (; stuff..., est = mk_est(stuff)) + for (extra_needed, mk_est) in ability_estimators_1d for stuff in _get_stuffs(setdiff(needed, Set((:est,))) ∪ extra_needed) - x = (; stuff..., est = mk_est(stuff)) - @yield x - end - end - return + ) end if :integrator in needed - for new_integrator in integrators_1d + return ( + (; stuff..., integrator = new_integrator) + for new_integrator in integrators_1d for stuff in _get_stuffs(setdiff(needed, Set((:integrator,)))) - x = (; stuff..., integrator = new_integrator) - @yield x - end - end - return + ) end if :optimizer in needed - pop!(needed, :optimizer) - for new_optimizer in optimizers_1d + return ( + (; stuff..., optimizer = new_optimizer) + for new_optimizer in optimizers_1d for stuff in _get_stuffs(setdiff(needed, Set((:optimizer,)))) - x = (; stuff..., optimizer = new_optimizer) - @yield x - end - end - return + ) end - x = NamedTuple() - @yield x - return + return [NamedTuple()] end -@resumable function get_stuffs(needed) - add_dummy_est = !(:est in needed) - for stuff in _get_stuffs(needed) - if add_dummy_est - stuff = (; stuff..., est = DummyAbilityEstimator(0.0)) - end - @yield stuff +function get_stuffs(needed) + if !(:est in needed) + return ( + (; stuff..., est = DummyAbilityEstimator(0.0)) + for stuff in _get_stuffs(needed) + ) + else + return _get_stuffs(needed) end end diff --git a/test/runtests.jl b/test/runtests.jl index 90f4823..75ccddc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,7 +17,6 @@ using Distributions using Distributions: ZeroMeanIsoNormal, Zeros, ScalMat using Optim using Random -using ResumableFunctions using Test diff --git a/test/stateful.jl b/test/stateful.jl index 2bb84c8..5e1ff80 100644 --- a/test/stateful.jl +++ b/test/stateful.jl @@ -7,7 +7,6 @@ using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule using ComputerAdaptiveTesting: Stateful using ComputerAdaptiveTesting: require_testext - using ResumableFunctions using Test: @test, @testset include("./dummy.jl") From 8992a35e57634d832b7aadedb349d9982c5b7889 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 15:26:29 +0300 Subject: [PATCH 13/42] Fix some typos in MirtCAT compat --- src/Compat/MirtCAT.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Compat/MirtCAT.jl b/src/Compat/MirtCAT.jl index 7d5a212..9551271 100644 --- a/src/Compat/MirtCAT.jl +++ b/src/Compat/MirtCAT.jl @@ -49,8 +49,8 @@ const next_item_aliases = Dict( distribution_estimator(posterior_ability_estimator), integrator )), - "Drule" => _next_item_helper((; bits, ability_estimator, rest...) -> DRuleItemCriteron(ability_estimator)), - "Trule" => _next_item_helper((; bits, ability_estimator, rest...) -> TRuleItemCriteron(ability_estimator)) + "Drule" => _next_item_helper((; bits, ability_estimator, rest...) -> DRuleItemCriterion(ability_estimator)), + "Trule" => _next_item_helper((; bits, ability_estimator, rest...) -> TRuleItemCriterion(ability_estimator)) ) # 'IKLP' as well as 'IKL' for the integration based Kullback-Leibler criteria with and without the prior density weight, From f3d9a6b796a8cad4ddd79c53dba38ad2bbead098 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 15:27:06 +0300 Subject: [PATCH 14/42] Add multidim to MirtCAT compat --- src/Compat/MirtCAT.jl | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/src/Compat/MirtCAT.jl b/src/Compat/MirtCAT.jl index 9551271..24fd5d0 100644 --- a/src/Compat/MirtCAT.jl +++ b/src/Compat/MirtCAT.jl @@ -71,9 +71,9 @@ to randomly select items, and 'seq' for selecting items sequentially =# const ability_estimator_aliases = Dict( - "MAP" => (; optimizer, kwargs...) -> ModeAbilityEstimator(PriorAbilityEstimator(), optimizer), - "ML" => (; optimizer, kwargs...) -> ModeAbilityEstimator(SafeLikelihoodAbilityEstimator(), optimizer), - "EAP" => (; integrator, kwargs...) -> MeanAbilityEstimator(PriorAbilityEstimator(), integrator), + "MAP" => (; optimizer, ncomp, kwargs...) -> ModeAbilityEstimator(PriorAbilityEstimator(; ncomp=ncomp), optimizer), + "ML" => (; optimizer, ncomp, kwargs...) -> ModeAbilityEstimator(SafeLikelihoodAbilityEstimator(; ncomp=ncomp), optimizer), + "EAP" => (; integrator, ncomp, kwargs...) -> MeanAbilityEstimator(PriorAbilityEstimator(; ncomp=ncomp), integrator), # "WLE" for weighted likelihood estimation # "EAPsum" for the expected a-posteriori for each sum score ) @@ -116,19 +116,34 @@ function setup_optimizer(lo=-6.0, hi=6.0) # https://stats.stackexchange.com/questions/272880/algorithm-used-in-nlm-function-in-r # So just use Newton() with defaults for now # Except then we can't have box constraints so I suppose IPNewton - Optimizers.OneDimOptimOptimizer(lo, hi, Optimizers.IPNewton()) + if lo isa AbstractVector && hi isa AbstractVector + Optimizers.MultiDimOptimOptimizer(lo, hi, Optimizers.IPNewton()) + else + Optimizers.OneDimOptimOptimizer(lo, hi, Optimizers.IPNewton()) + end end function assemble_rules(; criteria = "MI", method = "MAP", - start_item = 1 + start_item = 1, + ncomp = 0 ) - integrator = setup_integrator() - optimizer = setup_optimizer() - ability_estimator = ability_estimator_aliases[method](; integrator, optimizer) - posterior_ability_estimator = PriorAbilityEstimator() - @info "assemble rules" criteria + if ncomp == 0 + lo = -6.0 + hi = 6.0 + pts = mirtcat_quadpts(1) + theta_lim = 20.0 + else + lo = fill(-6.0, ncomp) + hi = fill(6.0, ncomp) + pts = mirtcat_quadpts(ncomp) + theta_lim = fill(20.0, ncomp) + end + integrator = setup_integrator(lo, hi, pts) + optimizer = setup_optimizer(-theta_lim, theta_lim) + ability_estimator = ability_estimator_aliases[method](; integrator, optimizer, ncomp) + posterior_ability_estimator = PriorAbilityEstimator(; ncomp) raw_next_item = next_item_aliases[criteria](ability_estimator, posterior_ability_estimator, integrator, optimizer) next_item = FixedFirstItemNextItemRule(start_item, raw_next_item) CatRules(; From b8675b915ff409a000c4f48868910564bcbdb730 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 15:28:05 +0300 Subject: [PATCH 15/42] Add RandomesqueStrategy --- Project.toml | 2 + src/next_item_rules/NextItemRules.jl | 1 + src/next_item_rules/strategies/randomesque.jl | 53 +++++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 src/next_item_rules/strategies/randomesque.jl diff --git a/Project.toml b/Project.toml index 01749db..63653fd 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" PsychometricsBazaarBase = "b0d9cada-d963-45e9-a4c6-4746243987f1" +QuickHeaps = "30b38841-0f52-47f8-a5f8-18d5d4064379" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -60,6 +61,7 @@ Mmap = "^1.11" Optim = "1.7.3" PrecompileTools = "1.2.1" PsychometricsBazaarBase = "^0.8.1" +QuickHeaps = "0.2.2" Random = "^1.11" Reexport = "1" Setfield = "^1" diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index ac97c89..a301968 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -65,6 +65,7 @@ include("./prelude/preallocate.jl") # Selection strategies include("./strategies/random.jl") +include("./strategies/randomesque.jl") include("./strategies/sequential.jl") include("./strategies/exhaustive.jl") diff --git a/src/next_item_rules/strategies/randomesque.jl b/src/next_item_rules/strategies/randomesque.jl new file mode 100644 index 0000000..ec8f0dc --- /dev/null +++ b/src/next_item_rules/strategies/randomesque.jl @@ -0,0 +1,53 @@ +using QuickHeaps: BinaryHeap, FastMax, Node, get_val +using StatsBase: sample + + +function randomesque( + rng::AbstractRNG, + objective::ItemCriterion, + responses::TrackedResponses, + items::AbstractItemBank, + k::Int +) + objective_state = init_thread(objective, responses) + heap = BinaryHeap{Node{Int, Float64}}(o = FastMax) + sizehint!(heap, k) + for item_idx in eachindex(items) + if (findfirst(idx -> idx == item_idx, responses.responses.indices) !== nothing) + continue + end + + obj_val = compute_criterion(objective, objective_state, responses, item_idx) + + if length(heap) < k + push!(heap, Node(item_idx, obj_val)) + elseif obj_val < get_val(peek(heap)) + heap[1] = Node(item_idx, obj_val) + end + end + if length(heap) >= 1 + Tuple(sample(rng, heap)) + else + return (-1, Inf) + end +end + +""" +$(TYPEDEF) +$(TYPEDFIELDS) + +""" +struct RandomesqueStrategy <: NextItemStrategy + rng::AbstractRNG + k::Int +end + +RandomesqueStrategy(k::Int) = RandomesqueStrategy(Xoshiro(), k) + +function best_item( + rule::ItemStrategyNextItemRule{RandomesqueStrategy, ItemCriterionT}, + responses::TrackedResponses, + items +) where {ItemCriterionT <: ItemCriterion} + randomesque(rule.rng, rule.criterion, responses, items, rule.k)[1] +end \ No newline at end of file From 8b035bf1f177f44d7dc2ce516f19f19e292a10bf Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 15:32:37 +0300 Subject: [PATCH 16/42] Make mirtcat compat tests be based on normal item bank --- test/compat.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/compat.jl b/test/compat.jl index 7acc663..54a6e9e 100644 --- a/test/compat.jl +++ b/test/compat.jl @@ -31,8 +31,7 @@ ) @testset "CatJL" begin - log_item_bank = LogItemBank(item_bank) - tracked_responses = TrackedResponses(half_responses, log_item_bank, NullAbilityTracker()) + tracked_responses = TrackedResponses(half_responses, item_bank, NullAbilityTracker()) for method in ("EAP", "MAP", "ML") @testset "Ability estimation $method" begin rules = Compat.MirtCAT.assemble_rules(; From d82fb658622f07901aef8483764c79e6321157e2 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 15:38:26 +0300 Subject: [PATCH 17/42] Reduce default margin in test_stateful_cat_item_bank_1d_dich_ib --- ext/TestExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TestExt.jl b/ext/TestExt.jl index 1448888..cc85f27 100644 --- a/ext/TestExt.jl +++ b/ext/TestExt.jl @@ -96,7 +96,7 @@ function test_stateful_cat_item_bank_1d_dich_ib( cat::Stateful.StatefulCat, item_bank::AbstractItemBank, points=[-.78, 0.0, .78], - margin=0.05, + margin=0.01, ) if length(item_bank) != Stateful.item_bank_size(cat) error("Item bank length does not match the cat's item bank size.") From 86a3951cd66cec02b1d7eb0d37d48646b91f3a9c Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 15:38:44 +0300 Subject: [PATCH 18/42] Add test_ability to TestExt --- ext/TestExt.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ext/TestExt.jl b/ext/TestExt.jl index cc85f27..996c573 100644 --- a/ext/TestExt.jl +++ b/ext/TestExt.jl @@ -110,4 +110,24 @@ function test_stateful_cat_item_bank_1d_dich_ib( end end +function test_ability( + cat1::Stateful.StatefulCat, + cat2::Stateful.StatefulCat, + item_bank_length; + margin=0.01 +) + if item_bank_length < 4 + error("Item bank length must be at least 4.") + end + for cat in (cat1, cat2) + Stateful.add_response!(cat, 1, false) + Stateful.add_response!(cat, 2, true) + Stateful.add_response!(cat, 3, false) + Stateful.add_response!(cat, 4, true) + end + ability1 = Stateful.get_ability(cat1) + ability2 = Stateful.get_ability(cat2) + @test ability1[1] ≈ ability2[1] rtol=margin +end + end \ No newline at end of file From e48dd80a5b766d28942a2d5257095812683659ca Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 15:40:26 +0300 Subject: [PATCH 19/42] Move sim to own directory --- src/ComputerAdaptiveTesting.jl | 2 +- src/{ => sim}/Sim.jl | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) rename src/{ => sim}/Sim.jl (86%) diff --git a/src/ComputerAdaptiveTesting.jl b/src/ComputerAdaptiveTesting.jl index 38137bf..1365275 100644 --- a/src/ComputerAdaptiveTesting.jl +++ b/src/ComputerAdaptiveTesting.jl @@ -34,7 +34,7 @@ include("./TerminationConditions.jl") # Combining / running include("./CatConfig.jl") -include("./Sim.jl") +include("./sim/Sim.jl") include("./decision_tree/DecisionTree.jl") # Stateful layer, compat, and comparison diff --git a/src/Sim.jl b/src/sim/Sim.jl similarity index 86% rename from src/Sim.jl rename to src/sim/Sim.jl index 39bc9b8..a01ae6e 100644 --- a/src/Sim.jl +++ b/src/sim/Sim.jl @@ -5,7 +5,17 @@ using StatsBase using FittedItemBanks: AbstractItemBank, ResponseType using ..Responses using ..CatConfig: CatLoopConfig, CatRules -using ..Aggregators: TrackedResponses, add_response!, Aggregators +using ..Aggregators: TrackedResponses, + add_response!, + Aggregators, + AbilityIntegrator, + AbilityEstimator, + LikelihoodAbilityEstimator, + PriorAbilityEstimator, + ModeAbilityEstimator, + MeanAbilityEstimator, + LikelihoodAbilityEstimator, + RiemannEnumerationIntegrator using ..NextItemRules: compute_criteria, best_item export run_cat, prompt_response, auto_responder From 8aab64ff9ad9fe32fecd443f0df713af43f5bfc1 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 15:42:57 +0300 Subject: [PATCH 20/42] Fix multidim expected_item_information --- src/next_item_rules/criteria/pointwise/information_support.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/next_item_rules/criteria/pointwise/information_support.jl b/src/next_item_rules/criteria/pointwise/information_support.jl index 4e798be..7fa9aa7 100644 --- a/src/next_item_rules/criteria/pointwise/information_support.jl +++ b/src/next_item_rules/criteria/pointwise/information_support.jl @@ -91,7 +91,7 @@ function expected_item_information(ir::ItemResponse, θ::Vector) exp_resp = resp_vec(ir, θ) n = domdims(ir.item_bank) hess = vector_hessian(θ -> log_resp_vec(ir, θ), θ, n) - -dropdims(sum((exp_resp .* (@view hess[2, :, :])), dims = 1), dims = 1) + return -sum(eachslice(hess, dims=1) .* exp_resp) end function known_item_information(ir::ItemResponse, resp_value, θ) From 98c3b06591e0d8ff4742629386ca5aaade455fda Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 15:43:53 +0300 Subject: [PATCH 21/42] Use expected_item_information for previous items in multidim case (mirtcat compatible) --- src/next_item_rules/criteria/item/information.jl | 14 ++++++++------ .../criteria/pointwise/information_support.jl | 6 ++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/next_item_rules/criteria/item/information.jl b/src/next_item_rules/criteria/item/information.jl index 88f7850..bcdbab6 100644 --- a/src/next_item_rules/criteria/item/information.jl +++ b/src/next_item_rules/criteria/item/information.jl @@ -26,14 +26,15 @@ function compute_criterion( return -item_criterion.expected_item_information(ir, ability) end -struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: +struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F, G} <: ItemMultiCriterion ability_estimator::AbilityEstimatorT - expected_item_information::F + known_item_information::F + expected_item_information::G end function InformationMatrixCriteria(ability_estimator) - InformationMatrixCriteria(ability_estimator, expected_item_information) + InformationMatrixCriteria(ability_estimator, expected_item_information, expected_item_information) end function init_thread(item_criterion::InformationMatrixCriteria, @@ -42,7 +43,8 @@ function init_thread(item_criterion::InformationMatrixCriteria, # θ update. # TODO: Update this to use track!(...) mechanism ability = maybe_tracked_ability_estimate(responses, item_criterion.ability_estimator) - responses_information(responses.item_bank, responses.responses, ability) + responses_information(responses.item_bank, responses.responses, ability; + information_func=item_criterion.known_item_information) end function compute_multi_criterion( @@ -52,9 +54,9 @@ function compute_multi_criterion( # TODO: Add in information from the prior ability = maybe_tracked_ability_estimate( tracked_responses, item_criterion.ability_estimator) - return acc_info .+ - item_criterion.expected_item_information( + exp_info = item_criterion.expected_item_information( ItemResponse(tracked_responses.item_bank, item_idx), ability) + return acc_info .+ exp_info end should_minimize(::InformationMatrixCriteria) = false diff --git a/src/next_item_rules/criteria/pointwise/information_support.jl b/src/next_item_rules/criteria/pointwise/information_support.jl index 7fa9aa7..9f987c6 100644 --- a/src/next_item_rules/criteria/pointwise/information_support.jl +++ b/src/next_item_rules/criteria/pointwise/information_support.jl @@ -94,14 +94,16 @@ function expected_item_information(ir::ItemResponse, θ::Vector) return -sum(eachslice(hess, dims=1) .* exp_resp) end +expected_item_information(ir::ItemResponse, _, θ::Vector) = expected_item_information(ir, θ) + function known_item_information(ir::ItemResponse, resp_value, θ) -ForwardDiff.hessian(θ -> log_resp(ir, resp_value, θ), θ) end -function responses_information(item_bank::AbstractItemBank, responses::BareResponses, θ) +function responses_information(item_bank::AbstractItemBank, responses::BareResponses, θ; information_func=known_item_information) d = domdims(item_bank) reduce(.+, - (known_item_information(ItemResponse(item_bank, resp_idx), resp_value > 0, θ) + (information_func(ItemResponse(item_bank, resp_idx), resp_value > 0, θ) for (resp_idx, resp_value) in zip(responses.indices, responses.values)); init = zeros(d, d)) end From 18d0488bbc6e3a172bccf8f587ece7558e150aaf Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 15:46:14 +0300 Subject: [PATCH 22/42] Fix dispatch for criteria --- src/next_item_rules/prelude/abstract.jl | 10 +++++----- src/next_item_rules/prelude/criteria.jl | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/next_item_rules/prelude/abstract.jl b/src/next_item_rules/prelude/abstract.jl index 5f3c665..5679712 100644 --- a/src/next_item_rules/prelude/abstract.jl +++ b/src/next_item_rules/prelude/abstract.jl @@ -32,9 +32,9 @@ $(TYPEDEF) Abstract base type all criteria should inherit from """ abstract type CriterionBase <: CatConfigBase end -abstract type ItemCriterionBase <: CatConfigBase end +abstract type SubItemCriterionBase <: CatConfigBase end -abstract type ItemCriterion <: ItemCriterionBase end +abstract type ItemCriterion <: CatConfigBase end """ $(TYPEDEF) @@ -44,17 +44,17 @@ abstract type StateCriterion <: CriterionBase end """ $(TYPEDEF) """ -abstract type PointwiseItemCriterion <: ItemCriterionBase end +abstract type PointwiseItemCriterion <: SubItemCriterionBase end """ $(TYPEDEF) """ -abstract type ItemCategoryCriterion <: ItemCriterionBase end +abstract type ItemCategoryCriterion <: SubItemCriterionBase end """ $(TYPEDEF) """ -abstract type PointwiseItemCategoryCriterion <: ItemCriterionBase end +abstract type PointwiseItemCategoryCriterion <: SubItemCriterionBase end abstract type MatrixScalarizer end abstract type StateMultiCriterion end diff --git a/src/next_item_rules/prelude/criteria.jl b/src/next_item_rules/prelude/criteria.jl index 1e2055e..31dede9 100644 --- a/src/next_item_rules/prelude/criteria.jl +++ b/src/next_item_rules/prelude/criteria.jl @@ -90,7 +90,7 @@ function compute_criteria( end function compute_criterion( - ppic::ItemCriterionBase, tracked_responses::TrackedResponses, item_idx, args...) + ppic::SubItemCriterionBase, tracked_responses::TrackedResponses, item_idx, args...) compute_criterion(ppic, ItemResponse(tracked_responses.item_bank, item_idx), args...) end From 869d6997e99680aecc9a762627c7374204c37c7d Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 16:57:07 +0300 Subject: [PATCH 23/42] Fix up and modularise tests --- test/ability_estimator_2d.jl | 2 +- test/dummy.jl | 2 +- test/smoke.jl | 12 ++++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/test/ability_estimator_2d.jl b/test/ability_estimator_2d.jl index 6929bea..c87cc8b 100644 --- a/test/ability_estimator_2d.jl +++ b/test/ability_estimator_2d.jl @@ -67,7 +67,7 @@ mle_mode_2d = ModeAbilityEstimator(lh_est_2d, optimizer_2d) # Item closer to the current estimate (1, 1) close_item = 5 # Item further from the current estimate - far_item = 6 + far_item = 7 close_info = compute_criterion( information_criterion, tracked_responses_2d, close_item) diff --git a/test/dummy.jl b/test/dummy.jl index 95b3895..bb5b56e 100644 --- a/test/dummy.jl +++ b/test/dummy.jl @@ -34,7 +34,7 @@ const ability_estimators_1d = [ ((:integrator,), (stuff) -> MeanAbilityEstimator(LikelihoodAbilityEstimator(), stuff.integrator)), ((:optimizer,), - (stuff) -> ModeAbilityEstimator(LikelihoodAbilityEstimator(), stuff.optimizer)) + (stuff) -> ModeAbilityEstimator(SafeLikelihoodAbilityEstimator(), stuff.optimizer)) ] const criteria_1d = [ ((:integrator, :est), diff --git a/test/smoke.jl b/test/smoke.jl index 4a4c176..593d08a 100644 --- a/test/smoke.jl +++ b/test/smoke.jl @@ -1,5 +1,17 @@ #(item_bank, abilities, responses) = dummy_full(Random.default_rng(42), SimpleItemBankSpec(StdModel4PL(), VectorContinuousDomain(), BooleanResponse()), 2; num_questions=100, num_testees=3) +using Random +using ComputerAdaptiveTesting +using ComputerAdaptiveTesting.Aggregators +using ComputerAdaptiveTesting.TerminationConditions +using ComputerAdaptiveTesting.Sim +using FittedItemBanks +using FittedItemBanks.DummyData: dummy_full, SimpleItemBankSpec, StdModel3PL, + VectorContinuousDomain, BooleanResponse, std_normal + +include("./dummy.jl") +using .Dummy + @testset "Smoke test 1d" begin (item_bank, abilities, true_responses) = dummy_full( Random.default_rng(42), From 55525c8d69be0d8b8f5583f9533c14edfe1bd42f Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sun, 8 Jun 2025 17:07:45 +0300 Subject: [PATCH 24/42] Bump PsychometricsBazaarBase req to 0.8.4 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 63653fd..9eacf4f 100644 --- a/Project.toml +++ b/Project.toml @@ -60,7 +60,7 @@ MacroTools = "^0.5.6" Mmap = "^1.11" Optim = "1.7.3" PrecompileTools = "1.2.1" -PsychometricsBazaarBase = "^0.8.1" +PsychometricsBazaarBase = "^0.8.4" QuickHeaps = "0.2.2" Random = "^1.11" Reexport = "1" From 3bed63810a607bdb182f4db068e76885b93f6b50 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Tue, 17 Jun 2025 17:05:55 +0300 Subject: [PATCH 25/42] Fixes to randomesque --- src/next_item_rules/strategies/randomesque.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/next_item_rules/strategies/randomesque.jl b/src/next_item_rules/strategies/randomesque.jl index ec8f0dc..2d0a6bf 100644 --- a/src/next_item_rules/strategies/randomesque.jl +++ b/src/next_item_rules/strategies/randomesque.jl @@ -10,7 +10,7 @@ function randomesque( k::Int ) objective_state = init_thread(objective, responses) - heap = BinaryHeap{Node{Int, Float64}}(o = FastMax) + heap = BinaryHeap{Node{Int, Float64}}(FastMax) sizehint!(heap, k) for item_idx in eachindex(items) if (findfirst(idx -> idx == item_idx, responses.responses.indices) !== nothing) @@ -36,6 +36,7 @@ end $(TYPEDEF) $(TYPEDFIELDS) +http://dx.doi.org/10.1207/s15324818ame0204_6 """ struct RandomesqueStrategy <: NextItemStrategy rng::AbstractRNG @@ -49,5 +50,5 @@ function best_item( responses::TrackedResponses, items ) where {ItemCriterionT <: ItemCriterion} - randomesque(rule.rng, rule.criterion, responses, items, rule.k)[1] + randomesque(rule.strategy.rng, rule.criterion, responses, items, rule.strategy.k)[1] end \ No newline at end of file From 6dd9ce950325bb4057d04b4165500f30f9b0b253 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 21 Jun 2025 10:34:57 +0300 Subject: [PATCH 26/42] Add GreedyForcedContentBalancer and PointwiseNextItemRule --- Project.toml | 2 +- src/next_item_rules/NextItemRules.jl | 8 +- src/next_item_rules/prelude/next_item_rule.jl | 7 ++ src/next_item_rules/strategies/balance.jl | 90 +++++++++++++++++++ src/next_item_rules/strategies/exhaustive.jl | 29 +++--- src/next_item_rules/strategies/pointwise.jl | 21 +++++ src/next_item_rules/strategies/sequential.jl | 6 +- 7 files changed, 147 insertions(+), 16 deletions(-) create mode 100644 src/next_item_rules/strategies/balance.jl create mode 100644 src/next_item_rules/strategies/pointwise.jl diff --git a/Project.toml b/Project.toml index 9eacf4f..3e3e7af 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ Distributions = "^0.25.88" DocStringExtensions = " ^0.9" EffectSizes = "^1.0.1" FillArrays = "0.13, 1.5.0" -FittedItemBanks = "^0.6.3, ^0.7.0" +FittedItemBanks = "^0.7.2" ForwardDiff = "1" HypothesisTests = "^0.10.12, ^0.11.0" Interpolations = "^0.14, ^0.15" diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index a301968..efe0f78 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -21,10 +21,11 @@ using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome, find1_instance, find1_type using PsychometricsBazaarBase.Integrators: Integrator, intval using PsychometricsBazaarBase: Integrators +using PsychometricsBazaarBase.IndentWrappers: indent import PsychometricsBazaarBase.IntegralCoeffs using FittedItemBanks: AbstractItemBank, DiscreteDomain, DomainType, ItemResponse, OneDimContinuousDomain, domdims, item_params, - resp, resp_vec, responses + resp, resp_vec, responses, subset_view using ..Aggregators using ..Aggregators: covariance_matrix, FunctionProduct @@ -34,6 +35,7 @@ using Base.Order using StaticArrays: SVector using ConstructionBase: constructorof import ForwardDiff +import Base: show export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread export NextItemRule, ItemStrategyNextItemRule @@ -46,7 +48,7 @@ export EmpiricalInformationPointwiseItemCategoryCriterion export TotalItemInformation export RandomNextItemRule export PiecewiseNextItemRule, MemoryNextItemRule, FixedFirstItemNextItemRule -export ExhaustiveSearch +export ExhaustiveSearch, RandomesqueStrategy export preallocate export compute_criteria, compute_criterion, compute_multi_criterion export best_item @@ -68,6 +70,8 @@ include("./strategies/random.jl") include("./strategies/randomesque.jl") include("./strategies/sequential.jl") include("./strategies/exhaustive.jl") +include("./strategies/pointwise.jl") +include("./strategies/balance.jl") # Combinators include("./combinators/expectation.jl") diff --git a/src/next_item_rules/prelude/next_item_rule.jl b/src/next_item_rules/prelude/next_item_rule.jl index a5bc007..f0311f2 100644 --- a/src/next_item_rules/prelude/next_item_rule.jl +++ b/src/next_item_rules/prelude/next_item_rule.jl @@ -53,6 +53,13 @@ function best_item(rule::NextItemRule, tracked_responses::TrackedResponses) best_item(rule, tracked_responses, tracked_responses.item_bank) end +function Base.show(io::IO, ::MIME"text/plain", next_item_rule::ItemStrategyNextItemRule) + println(io, "Strategy:") + show(indent_io, MIME"text/plain"(), rules.strategy) + println(io, "Item criterion:") + show(indent_io, MIME"text/plain"(), rules.criterion) +end + # Default implementation function compute_criteria(::NextItemRule, ::TrackedResponses) nothing diff --git a/src/next_item_rules/strategies/balance.jl b/src/next_item_rules/strategies/balance.jl new file mode 100644 index 0000000..2d3e6cd --- /dev/null +++ b/src/next_item_rules/strategies/balance.jl @@ -0,0 +1,90 @@ +""" +$(TYPEDEF) +$(TYPEDFIELDS) + +This content balancing procedure takes target proportions for each group of items. +At each step the group with the lowest ratio of seen items to target is selected. + +http://dx.doi.org/10.1207/s15324818ame0403_4 +""" +struct GreedyForcedContentBalancer{InnerRuleT <: NextItemRule} <: NextItemRule + targets::Vector{Float64} + groups::Vector{Int} + inner_rule::InnerRuleT +end + +function GreedyForcedContentBalancer(targets::Dict, groups, bits...) + targets_vec = zeros(Float64, length(targets)) + groups_idxs = zeros(Int, length(groups)) + group_lookup = Dict{Any, Int}() + for (idx, group) in enumerate(groups) + if haskey(group_lookup, group) + group_idx = group_lookup[group] + else + group_idx = length(group_lookup) + 1 + group_lookup[group] = group_idx + end + groups_idxs[idx] = group_idx + end + if length(group_lookup) != length(targets) + error("Number of groups $(length(group_lookup)) does not match number of targets $(length(targets))") + end + for (group, group_idx) in pairs(group_lookup) + targets_vec[group_idx] = get(targets, group, 0.0) + end + GreedyForcedContentBalancer(targets_vec, groups_idxs, bits...) +end + +function GreedyForcedContentBalancer(targets::AbstractVector, groups, bits...) + GreedyForcedContentBalancer(targets, groups, NextItemRule(bits...)) +end + +function show(io::IO, ::MIME"text/plain", rule::GreedyForcedContentBalancer) + indent_io = indent(io, 2) + println(io, "Greedy + forced content balancer") + println(indent_io, "Target ratio: " * join(rule.targets, ", ")) + print(indent_io, "Using rule: ") + show(indent_io, MIME("text/plain"), rule.inner_rule) +end + +function next_item_bank(targets, groups, responses, items) + seen = zeros(UInt, size(targets)) + indices = responses.responses.indices + for group_idx in groups[indices] + seen[group_idx] += 1 + end + next_group_idx = argmin(seen ./ targets) + matching_indicator = groups .== next_group_idx + next_items = subset_view(items, matching_indicator) + return (next_items, matching_indicator) +end + +function best_item( + rule::GreedyForcedContentBalancer, + responses::TrackedResponses, + items +) + next_items, matching_indicator = next_item_bank(rule.targets, rule.groups, responses, items) + inner_idx = best_item(rule.inner_rule, responses, next_items) + for (outer_idx, in_group) in enumerate(matching_indicator) + if in_group + inner_idx -= 1 + if inner_idx <= 0 + return outer_idx + end + end + end + error("No item found in group length $(length(next_items)) with inner index $inner_idx") +end + +function compute_criteria( + rule::GreedyForcedContentBalancer, + responses::TrackedResponses, + items +) + next_items, matching_indicator = next_item_bank(rule.targets, rule.groups, responses, items) + criteria = compute_criteria(rule.inner_rule, responses, next_items) + expanded = fill(Inf, length(items)) + expanded[matching_indicator] .= criteria + return expanded +end \ No newline at end of file diff --git a/src/next_item_rules/strategies/exhaustive.jl b/src/next_item_rules/strategies/exhaustive.jl index 4c13255..849b91b 100644 --- a/src/next_item_rules/strategies/exhaustive.jl +++ b/src/next_item_rules/strategies/exhaustive.jl @@ -1,21 +1,18 @@ -function exhaustive_search(objective::ItemCriterionT, - responses::TrackedResponseT, - items::AbstractItemBank)::Tuple{ - Int, - Float64 -} where {ItemCriterionT <: ItemCriterion, TrackedResponseT <: TrackedResponses} - #pre_next_item(expectation_tracker, items) - objective_state = init_thread(objective, responses) +function exhaustive_search( + callback, + answered_items::AbstractVector{Int}, + items::AbstractItemBank +)::Tuple{Int, Float64} min_obj_idx::Int = -1 min_obj_val::Float64 = Inf for item_idx in eachindex(items) # TODO: Add these back in #@init irf_states_storage = zeros(Int, length(responses) + 1) - if (findfirst(idx -> idx == item_idx, responses.responses.indices) !== nothing) + if (findfirst(idx -> idx == item_idx, answered_items) !== nothing) continue end - obj_val = compute_criterion(objective, objective_state, responses, item_idx) + obj_val = callback(item_idx) if obj_val <= min_obj_val min_obj_val = obj_val @@ -25,6 +22,18 @@ function exhaustive_search(objective::ItemCriterionT, return (min_obj_idx, min_obj_val) end +function exhaustive_search(objective::ItemCriterionT, + responses::TrackedResponseT, + items::AbstractItemBank)::Tuple{ + Int, + Float64 +} where {ItemCriterionT <: ItemCriterion, TrackedResponseT <: TrackedResponses} + objective_state = init_thread(objective, responses) + return exhaustive_search(responses.responses.indices, items) do item_idx + return compute_criterion(objective, objective_state, responses, item_idx) + end +end + """ $(TYPEDEF) $(TYPEDFIELDS) diff --git a/src/next_item_rules/strategies/pointwise.jl b/src/next_item_rules/strategies/pointwise.jl new file mode 100644 index 0000000..80c0709 --- /dev/null +++ b/src/next_item_rules/strategies/pointwise.jl @@ -0,0 +1,21 @@ +struct PointwiseNextItemRule{CriterionT <: PointwiseItemCriterion, PointsT <: AbstractArray{<:Number}} <: NextItemRule + criterion::CriterionT + points::PointsT +end + +function best_item(rule::PointwiseNextItemRule, responses::TrackedResponses, items) + num_responses = length(responses.responses.indices) + next_index = num_responses + 1 + if next_index > length(rule.points) + error("Number of responses exceeds the number of points defined in the rule.") + end + current_point = rule.points[next_index] + idx, _ = exhaustive_search(responses.responses.indices, items) do item_idx + return compute_criterion(rule.criterion, ItemResponse(items, item_idx), current_point) + end + return idx +end + +function PointwiseFirstNextItemRule(criterion, points, rule) + PiecewiseNextItemRule((length(points),), (PointwiseNextItemRule(criterion, points), rule)) +end \ No newline at end of file diff --git a/src/next_item_rules/strategies/sequential.jl b/src/next_item_rules/strategies/sequential.jl index 12676b2..845e3ae 100644 --- a/src/next_item_rules/strategies/sequential.jl +++ b/src/next_item_rules/strategies/sequential.jl @@ -6,10 +6,10 @@ This is the most basic rule for choosing the next item in a CAT. It simply picks a random item from the set of items that have not yet been administered. """ -@kwdef struct PiecewiseNextItemRule{BreaksT, RulesT} <: NextItemRule +@kwdef struct PiecewiseNextItemRule{RulesT} <: NextItemRule # Tuple of Ints - breaks::BreaksT - # Types of NextItemRules + breaks::Tuple{Int} + # Tuple of NextItemRules rules::RulesT end From d9151d99be32bb35c5d70694ca2e4c67ecfa0bc3 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 21 Jun 2025 10:44:21 +0300 Subject: [PATCH 27/42] Mark some stuff as scalars for broadcasting --- src/aggregators/Aggregators.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/aggregators/Aggregators.jl b/src/aggregators/Aggregators.jl index e2578eb..bd302df 100644 --- a/src/aggregators/Aggregators.jl +++ b/src/aggregators/Aggregators.jl @@ -71,6 +71,9 @@ function AbilityEstimator(::ContinuousDomain, bits...) integrator) end +# Mark as a scalar for broadcasting +Base.broadcastable(ir::AbilityEstimator) = Ref(ir) + abstract type DistributionAbilityEstimator <: AbilityEstimator end function DistributionAbilityEstimator(bits...) @returnsome find1_instance(DistributionAbilityEstimator, bits) @@ -171,6 +174,9 @@ function TrackedResponses(responses, item_bank) TrackedResponses(responses, item_bank, NullAbilityTracker()) end +# Mark as a scalar for broadcasting +Base.broadcastable(ir::TrackedResponses) = Ref(ir) + function Responses.AbilityLikelihood(tracked_responses::TrackedResponses{ BareResponsesT, ItemBankT, From cf02e1146ad4c498255106f22aefadb70f61880e Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 21 Jun 2025 10:46:11 +0300 Subject: [PATCH 28/42] Additions and refactor * Add in CatRecorder * CatLoopConfig => CatLoop * Move CatLoop into Sim --- Project.toml | 2 + .../examples/ability_convergence_3pl.jl | 2 +- .../examples/ability_convergence_mirt.jl | 2 +- docs/examples/examples/vocab_iq.jl | 2 +- docs/src/api.md | 2 +- docs/src/creating_a_cat.md | 4 +- docs/src/stateful.md | 2 +- docs/src/using_your_cat.md | 4 +- src/Compat/CatR.jl | 2 +- src/Compat/MirtCAT.jl | 2 +- src/ComputerAdaptiveTesting.jl | 10 +- src/{CatConfig.jl => Rules.jl} | 47 +-- src/Stateful.jl | 6 +- src/TerminationConditions.jl | 9 +- src/decision_tree/sim.jl | 4 +- src/precompiles.jl | 3 +- src/sim/Sim.jl | 98 +---- src/sim/loop.jl | 54 +++ src/sim/recorder.jl | 365 ++++++++++++++++++ src/sim/run.jl | 84 ++++ test/dt.jl | 6 +- test/runtests.jl | 3 +- test/smoke.jl | 2 +- 23 files changed, 566 insertions(+), 149 deletions(-) rename src/{CatConfig.jl => Rules.jl} (79%) create mode 100644 src/sim/loop.jl create mode 100644 src/sim/recorder.jl create mode 100644 src/sim/run.jl diff --git a/Project.toml b/Project.toml index 3e3e7af..3c91918 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EffectSizes = "e248de7e-9197-5860-972e-353a2af44d75" +ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FittedItemBanks = "3f797b09-34e4-41d7-acf6-3302ae3248a5" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -47,6 +48,7 @@ DataFrames = "1.6.1" Distributions = "^0.25.88" DocStringExtensions = " ^0.9" EffectSizes = "^1.0.1" +ElasticArrays = "1.2.12" FillArrays = "0.13, 1.5.0" FittedItemBanks = "^0.7.2" ForwardDiff = "1" diff --git a/docs/examples/examples/ability_convergence_3pl.jl b/docs/examples/examples/ability_convergence_3pl.jl index cb3608e..eefc264 100644 --- a/docs/examples/examples/ability_convergence_3pl.jl +++ b/docs/examples/examples/ability_convergence_3pl.jl @@ -57,7 +57,7 @@ xs = range(-2.5, 2.5, length = points) raw_estimator = LikelihoodAbilityEstimator() recorder = CatRecorder(xs, responses, integrator, raw_estimator, ability_estimator) for testee_idx in axes(responses, 2) - tracked_responses, θ = run_cat(CatLoopConfig(rules = rules, + tracked_responses, θ = run_cat(CatLoop(rules = rules, get_response = auto_responder(@view responses[:, testee_idx]), new_response_callback = (tracked_responses, terminating) -> recorder(tracked_responses, testee_idx, diff --git a/docs/examples/examples/ability_convergence_mirt.jl b/docs/examples/examples/ability_convergence_mirt.jl index d71ac95..e481b87 100644 --- a/docs/examples/examples/ability_convergence_mirt.jl +++ b/docs/examples/examples/ability_convergence_mirt.jl @@ -67,7 +67,7 @@ recorder = CatRecorder(xs, abilities) for testee_idx in axes(responses, 2) @debug "Running for testee" testee_idx - tracked_responses, θ = run_cat(CatLoopConfig(rules = rules, + tracked_responses, θ = run_cat(CatLoop(rules = rules, get_response = auto_responder(@view responses[:, testee_idx]), new_response_callback = (tracked_responses, terminating) -> recorder(tracked_responses, testee_idx, diff --git a/docs/examples/examples/vocab_iq.jl b/docs/examples/examples/vocab_iq.jl index 01bad7e..4c67479 100644 --- a/docs/examples/examples/vocab_iq.jl +++ b/docs/examples/examples/vocab_iq.jl @@ -63,7 +63,7 @@ function run_vocab_iq_cat() println("Got ability estimate: $ability ± $var") println("") end - loop_config = CatLoopConfig(rules = rules, + loop_config = CatLoop(rules = rules, get_response = get_response, new_response_callback = new_response_callback) run_cat(loop_config, item_bank) diff --git a/docs/src/api.md b/docs/src/api.md index eb2a274..8755a67 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,5 +8,5 @@ CurrentModule = ComputerAdaptiveTesting ``` ```@autodocs -Modules = [ComputerAdaptiveTesting, ComputerAdaptiveTesting.Aggregators, ComputerAdaptiveTesting.Responses, ComputerAdaptiveTesting.Sim, ComputerAdaptiveTesting.TerminationConditions, ComputerAdaptiveTesting.NextItemRules, ComputerAdaptiveTesting.CatConfig] +Modules = [ComputerAdaptiveTesting, ComputerAdaptiveTesting.Aggregators, ComputerAdaptiveTesting.Responses, ComputerAdaptiveTesting.Sim, ComputerAdaptiveTesting.TerminationConditions, ComputerAdaptiveTesting.NextItemRules, ComputerAdaptiveTesting.Rules] ``` diff --git a/docs/src/creating_a_cat.md b/docs/src/creating_a_cat.md index 8b2664b..ba30b89 100644 --- a/docs/src/creating_a_cat.md +++ b/docs/src/creating_a_cat.md @@ -13,7 +13,7 @@ The configuration of a CAT is built up as a tree of configuration structs. These structs are all subtypes of `CatConfigBase`. ```@docs; canonical=false -ComputerAdaptiveTesting.CatConfig.CatConfigBase +ComputerAdaptiveTesting.ConfigBase.CatConfigBase ``` The constructors for the configuration structs in this package tend to have @@ -59,7 +59,7 @@ next item selection rule, and the stopping rule. `CatRules` has explicit and implicit constructors. ```@docs; canonical=false -ComputerAdaptiveTesting.CatConfig.CatRules +ComputerAdaptiveTesting.CatRules ``` ### Next item selection with `NextItemRule` diff --git a/docs/src/stateful.md b/docs/src/stateful.md index 2c62260..aaa24eb 100644 --- a/docs/src/stateful.md +++ b/docs/src/stateful.md @@ -33,4 +33,4 @@ Stateful.StatefulCatConfig ## Usage -Just as [CatLoopConfig](@ref) can wrap [CatRules](@ref), you can also use it with any implementor of [Stateful.StatefulCat](@ref), and run using [Sim.run_cat](@ref). \ No newline at end of file +Just as [CatLoop](@ref) can wrap [CatRules](@ref), you can also use it with any implementor of [Stateful.StatefulCat](@ref), and run using [Sim.run_cat](@ref). \ No newline at end of file diff --git a/docs/src/using_your_cat.md b/docs/src/using_your_cat.md index dd33d4d..04711c1 100644 --- a/docs/src/using_your_cat.md +++ b/docs/src/using_your_cat.md @@ -9,10 +9,10 @@ a number of ways you can use it. This section covers a few. See also the [Examples](@ref demo-page). -When you've set up your CAT using [CatRules](@ref), you can wrap it in a [CatLoopConfig](@ref) and run it with [run_cat](@ref). +When you've set up your CAT using [CatRules](@ref), you can wrap it in a [CatLoop](@ref) and run it with [run_cat](@ref). ```@docs; canonical=false -CatLoopConfig +CatLoop run_cat ``` diff --git a/src/Compat/CatR.jl b/src/Compat/CatR.jl index 1314627..6af9eb9 100644 --- a/src/Compat/CatR.jl +++ b/src/Compat/CatR.jl @@ -7,7 +7,7 @@ using ComputerAdaptiveTesting.Aggregators: AbilityIntegrator, MeanAbilityEstimator, PriorAbilityEstimator using ComputerAdaptiveTesting.TerminationConditions: RunForeverTerminationCondition -using ComputerAdaptiveTesting.CatConfig: CatRules +using ComputerAdaptiveTesting.Rules: CatRules using ComputerAdaptiveTesting.NextItemRules using PsychometricsBazaarBase: Integrators, Optimizers diff --git a/src/Compat/MirtCAT.jl b/src/Compat/MirtCAT.jl index 24fd5d0..97706e7 100644 --- a/src/Compat/MirtCAT.jl +++ b/src/Compat/MirtCAT.jl @@ -9,8 +9,8 @@ using ComputerAdaptiveTesting.Aggregators: SafeLikelihoodAbilityEstimator, AbilityEstimator, distribution_estimator using ComputerAdaptiveTesting.TerminationConditions: RunForeverTerminationCondition -using ComputerAdaptiveTesting.CatConfig: CatRules using ComputerAdaptiveTesting.NextItemRules +using ComputerAdaptiveTesting: CatRules using PsychometricsBazaarBase: Integrators, Optimizers public next_item_aliases, ability_estimator_aliases, assemble_rules diff --git a/src/ComputerAdaptiveTesting.jl b/src/ComputerAdaptiveTesting.jl index 1365275..033ba97 100644 --- a/src/ComputerAdaptiveTesting.jl +++ b/src/ComputerAdaptiveTesting.jl @@ -5,9 +5,9 @@ include("./hacks.jl") using Reexport: Reexport, @reexport # Modules -export ConfigBase, Responses, Aggregators +export Responses, Aggregators export NextItemRules, TerminationConditions -export CatConfig, Sim, DecisionTree +export Sim, DecisionTree export Stateful, Comparison # Extension modules @@ -33,7 +33,7 @@ include("./next_item_rules/NextItemRules.jl") include("./TerminationConditions.jl") # Combining / running -include("./CatConfig.jl") +include("./Rules.jl") include("./sim/Sim.jl") include("./decision_tree/DecisionTree.jl") @@ -42,8 +42,8 @@ include("./Stateful.jl") include("./Compat/Compat.jl") include("./Comparison/Comparison.jl") -@reexport using .CatConfig: CatLoopConfig, CatRules -@reexport using .Sim: run_cat +@reexport using .Rules: CatRules +@reexport using .Sim: CatLoop, run_cat @reexport using .NextItemRules: preallocate include("./precompiles.jl") diff --git a/src/CatConfig.jl b/src/Rules.jl similarity index 79% rename from src/CatConfig.jl rename to src/Rules.jl index e809f39..40e7f34 100644 --- a/src/CatConfig.jl +++ b/src/Rules.jl @@ -1,15 +1,17 @@ -module CatConfig +module Rules -export CatRules, CatLoopConfig +export CatRules using DocStringExtensions using PsychometricsBazaarBase.Parameters +using PsychometricsBazaarBase.IndentWrappers: indent using ..Aggregators: AbilityEstimator, AbilityTracker, ConsAbilityTracker, NullAbilityTracker using ..NextItemRules: NextItemRule using ..TerminationConditions: TerminationCondition using ..ConfigBase +import Base: show """ $(TYPEDEF) @@ -19,7 +21,7 @@ Configuration of the rules for a CAT. This all includes all the basic rules for the CAT's operation, but not the item bank, nor any of the interactivity hooks needed to actually run the CAT. -This may be more a more convenient layer to integrate than CatLoopConfig if you +This may be more a more convenient layer to integrate than CatLoop if you want to write your own CAT loop rather than using hooks. $(FUNCTIONNAME)(; next_item=..., termination_condition=..., ability_estimator=..., ability_tracker=...) @@ -79,6 +81,16 @@ function CatRules(bits...) ability_tracker = collect_trackers(next_item, ability_tracker)) end +function show(io::IO, ::MIME"text/plain", rules::CatRules) + indent_io = indent(io, 2) + println(io, "Next item rule:") + show(indent_io, MIME"text/plain"(), rules.next_item) + println(io, "Termination condition:") + show(indent_io, MIME"text/plain"(), rules.termination_condition) + println(io, "Ability estimator:") + show(indent_io, MIME"text/plain"(), rules.ability_estimator) +end + function _find_ability_estimator_and_tracker(bits...) ability_estimator = AbilityEstimator(bits...) ability_tracker = AbilityTracker(bits...; ability_estimator = ability_estimator) @@ -113,33 +125,4 @@ function collect_trackers(next_item_rule::NextItemRule, ability_tracker::Ability end end -""" -```julia -struct $(FUNCTIONNAME) -$(FUNCTIONNAME)(; rules=..., get_response=..., new_response_callback=...) -``` -$(TYPEDFIELDS) - -Configuration for a simulatable CAT. -""" -@with_kw struct CatLoopConfig{CatEngineT} <: CatConfigBase - """ - An object which implements the CAT engine. - Implementations exist for: - * [CatRules](@ref) - * [Stateful.StatefulCat](@ref ComputerAdaptiveTesting.Stateful.StatefulCat) - """ - rules::CatEngineT # e.g. CatRules - """ - The function `(index, label) -> Int8`` which obtains the testee's response for - a given question, e.g. by prompting or simulation from data. - """ - get_response::Any - """ - A callback called each time there is a new responses. - If provided, it is passed `(responses::TrackedResponses, terminating)`. - """ - new_response_callback = nothing -end - end diff --git a/src/Stateful.jl b/src/Stateful.jl index 38450a0..589651f 100644 --- a/src/Stateful.jl +++ b/src/Stateful.jl @@ -9,10 +9,10 @@ using DocStringExtensions using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, resp_vec using ..Aggregators: TrackedResponses, Aggregators, pdf, distribution_estimator -using ..CatConfig: CatLoopConfig, CatRules +using ..Rules: CatRules using ..Responses: BareResponses, Response, Responses using ..NextItemRules: compute_criteria, best_item -using ..Sim: Sim, item_label +using ..Sim: CatLoop, Sim, item_label export StatefulCat, StatefulCatConfig public next_item, ranked_items, item_criteria @@ -156,7 +156,7 @@ model backing the CAT. function item_response_functions end ## Running the CAT -function Sim.run_cat(cat_config::CatLoopConfig{RulesT}, +function Sim.run_cat(cat_config::CatLoop{RulesT}, ib_labels = nothing) where {RulesT <: StatefulCat} (; stateful_cat, get_response, new_response_callback) = cat_config while true diff --git a/src/TerminationConditions.jl b/src/TerminationConditions.jl index f0b5261..5c99c69 100644 --- a/src/TerminationConditions.jl +++ b/src/TerminationConditions.jl @@ -8,7 +8,7 @@ using PsychometricsBazaarBase.ConfigTools: @returnsome, find1_instance using FittedItemBanks export TerminationCondition, - FixedItemsTerminationCondition, SimpleFunctionTerminationCondition + LengthTerminationCondition, SimpleFunctionTerminationCondition export RunForeverTerminationCondition """ @@ -24,14 +24,17 @@ end $(TYPEDEF) $(TYPEDFIELDS) """ -struct FixedItemsTerminationCondition{} <: TerminationCondition +struct LengthTerminationCondition{} <: TerminationCondition num_items::Int64 end -function (condition::FixedItemsTerminationCondition)(responses::TrackedResponses, +function (condition::LengthTerminationCondition)(responses::TrackedResponses, items::AbstractItemBank) length(responses) >= condition.num_items end +# Alias for old name +const FixedItemsTerminationCondition = LengthTerminationCondition + struct SimpleFunctionTerminationCondition{F} <: TerminationCondition func::F end diff --git a/src/decision_tree/sim.jl b/src/decision_tree/sim.jl index 73b9db5..5d67f84 100644 --- a/src/decision_tree/sim.jl +++ b/src/decision_tree/sim.jl @@ -1,9 +1,9 @@ import ComputerAdaptiveTesting: Sim """ -Run a given CatLoopConfig with a MaterializedDecisionTree +Run a given CatLoop with a MaterializedDecisionTree """ -function Sim.run_cat(cat_config::Sim.CatLoopConfig{RulesT}, +function Sim.run_cat(cat_config::Sim.CatLoop{RulesT}, item_bank::AbstractItemBank; ib_labels = nothing) where {RulesT <: MaterializedDecisionTree} (; rules, get_response, new_response_callback) = cat_config diff --git a/src/precompiles.jl b/src/precompiles.jl index 3225c8f..5147e9c 100644 --- a/src/precompiles.jl +++ b/src/precompiles.jl @@ -10,6 +10,7 @@ using PrecompileTools: @compile_workload, @setup_workload using .NextItemRules: preallocate, ExhaustiveSearch, ItemStrategyNextItemRule, ExpectationBasedItemCriterion, AbilityVarianceStateCriterion using .Stateful: Stateful + using .ComputerAdaptiveTesting: CatRules rng = default_rng(42) spec = SimpleItemBankSpec(StdModel2PL(), OneDimContinuousDomain(), BooleanResponse()) @@ -24,7 +25,7 @@ using PrecompileTools: @compile_workload, @setup_workload ExhaustiveSearch(), ExpectationBasedItemCriterion(ability_estimator, AbilityVarianceStateCriterion(ability_estimator))) - cat = Stateful.StatefulCatConfig(CatConfig.CatRules(; + cat = Stateful.StatefulCatConfig(CatRules(; next_item=next_item_rule, termination_condition=TerminationConditions.RunForeverTerminationCondition(), ability_estimator=ability_estimator diff --git a/src/sim/Sim.jl b/src/sim/Sim.jl index a01ae6e..5cfe830 100644 --- a/src/sim/Sim.jl +++ b/src/sim/Sim.jl @@ -1,10 +1,14 @@ module Sim +using ElasticArrays +using ElasticArrays: sizehint_lastdim! using DocStringExtensions using StatsBase -using FittedItemBanks: AbstractItemBank, ResponseType +using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse +using PsychometricsBazaarBase.Integrators +using ..ConfigBase using ..Responses -using ..CatConfig: CatLoopConfig, CatRules +using ..Rules: CatRules using ..Aggregators: TrackedResponses, add_response!, Aggregators, @@ -17,92 +21,14 @@ using ..Aggregators: TrackedResponses, LikelihoodAbilityEstimator, RiemannEnumerationIntegrator using ..NextItemRules: compute_criteria, best_item +import Base: show +export CatRecorder, CatRecording +export CatLoop, record! export run_cat, prompt_response, auto_responder -""" -$(TYPEDSIGNATURES) - -This response callback simply prompts the user for the response using the console -""" -function prompt_response(index_, label) - println("Response for $label > ") - parse(Int8, readline()) -end - -""" -$(TYPEDSIGNATURES) - -This function constructs a next item function which automatically responds -according to `responses`. -""" -function auto_responder(responses) - function (index, label_) - responses[index] - end -end - -abstract type NextItemError <: Exception end - -function item_label(ib_labels, next_index) - default_next_label(next_index) = "<>" - if ib_labels === nothing - return default_next_label(next_index) - else - return get(default_next_label, ib_labels, next_index) - end -end - -""" -```julia -$(FUNCTIONNAME)(cat_config::CatLoopConfig, item_bank::AbstractItemBank; ib_labels=nothing) -``` - -Run a given [CatLoopConfig](@ref) `cat_config` on the given `item_bank`. -If `ib_labels` is not given, default labels of the form -`<>` are passed to the callback. -""" -function run_cat(cat_config::CatLoopConfig{RulesT}, - item_bank::AbstractItemBank; - ib_labels = nothing) where {RulesT <: CatRules} - (; rules, get_response, new_response_callback) = cat_config - (; next_item, termination_condition, ability_estimator, ability_tracker) = rules - responses = TrackedResponses(BareResponses(ResponseType(item_bank)), - item_bank, - ability_tracker) - while true - local next_index - @debug begin - criteria = compute_criteria(next_item, responses, item_bank) - "Best items" - end criteria - try - next_index = best_item(next_item, responses, item_bank) - catch exc - if isa(exc, NextItemError) - @warn "Terminating early due to error getting next item" err=sprint( - showerror, - exc) - break - else - rethrow() - end - end - next_label = item_label(ib_labels, next_index) - @debug "Querying" next_label - response = get_response(next_index, next_label) - @debug "Got response" response - add_response!(responses, Response(ResponseType(item_bank), next_index, response)) - terminating = termination_condition(responses, item_bank) - if new_response_callback !== nothing - new_response_callback(responses, terminating) - end - if terminating - @debug "Met termination condition" - break - end - end - (responses.responses, ability_estimator(responses)) -end +include("./recorder.jl") +include("./loop.jl") +include("./run.jl") end diff --git a/src/sim/loop.jl b/src/sim/loop.jl new file mode 100644 index 0000000..629968c --- /dev/null +++ b/src/sim/loop.jl @@ -0,0 +1,54 @@ +""" +```julia +struct $(FUNCTIONNAME) +$(FUNCTIONNAME)(; rules=..., get_response=..., new_response_callback=...) +``` +$(TYPEDFIELDS) + +Configuration for a simulatable CAT. +""" +struct CatLoop{CatEngineT} <: CatConfigBase + """ + An object which implements the CAT engine. + Implementations exist for: + * [CatRules](@ref) + * [Stateful.StatefulCat](@ref ComputerAdaptiveTesting.Stateful.StatefulCat) + """ + rules::CatEngineT # e.g. CatRules + """ + The function `(index, label) -> Int8`` which obtains the testee's response for + a given question, e.g. by prompting or simulation from data. + """ + get_response::Any + """ + A callback called each time there is a new responses. + If provided, it is passed `(responses::TrackedResponses, terminating)`. + """ + new_response_callback +end + +function CatLoop(; + rules, + get_response, + new_response_callback = nothing, + new_response_callbacks = Any[], + recorder = nothing +) + new_response_callbacks = collect(new_response_callbacks) + if new_response_callback !== nothing + push!(new_response_callbacks, new_response_callback) + end + if recorder !== nothing && showable(MIME("text/plain"), rules) + buf = IOBuffer() + show(buf, MIME("text/plain"), rules) + recorder.recording.rules_description = String(take!(buf)) + push!(new_response_callbacks, catrecorder_callback(recorder)) + end + function all_callbacks(responses, terminating) + for callback in new_response_callbacks + callback(responses, terminating) + end + nothing + end + CatLoop{typeof(rules)}(rules, get_response, all_callbacks) +end \ No newline at end of file diff --git a/src/sim/recorder.jl b/src/sim/recorder.jl new file mode 100644 index 0000000..2b017e9 --- /dev/null +++ b/src/sim/recorder.jl @@ -0,0 +1,365 @@ +function empty_capacity(typ, size) + ret = typ[] + sizehint!(ret, size) + return ret +end + +function empty_capacity(typ, dims...) + ret = ElasticArray{typ}(undef, dims[1:end - 1]..., 0) + sizehint_lastdim!(ret, dims[end]) + return ret +end + +# Elastic arrays do not support `push!` directly, so we define our own +elastic_push!(xs::AbstractVector, value) = push!(xs, value) +elastic_push!(xs::ElasticArray, value) = append!(xs, value) + +Base.@kwdef mutable struct CatRecording{LikelihoodsT <: NamedTuple} + #ability_ests::AbilityVecT + #xs::Union{Nothing, AbilityVecT} + #likelihoods::Matrix{Float64} + #raw_likelihoods::Matrix{Float64} + data::LikelihoodsT + item_responses::Vector{Float64} + item_index::Vector{Int} + item_correctness::Vector{Bool} + rules_description::Union{Nothing, String} = nothing +end + +Base.@kwdef struct CatRecorder{RequestsT <: NamedTuple, LikelihoodsT <: NamedTuple} + recording::CatRecording{LikelihoodsT} + requests::RequestsT + #integrator::AbilityIntegrator + #raw_estimator::LikelihoodAbilityEstimator + #ability_estimator::AbilityEstimator +end + +function CatRecording( + data, + expected_responses=0 +) + CatRecording(; + data=data, + item_responses=empty_capacity(Float64, expected_responses), + item_index=empty_capacity(Int, expected_responses), + item_correctness=empty_capacity(Bool, expected_responses) + ) +end + +function show(io::IO, ::MIME"text/plain", recording::CatRecording) + println(io, "Recording of a Computer-Adaptive Test") + if recording.rules_description === nothing + println(io, " Unknown CAT configuration") + else + println(io, " CAT configuration:") + for line in split(recording.rules_description, "\n") + println(io, " ", line) + end + end + println(io, " item_responses: ", length(recording.item_responses)) + println(io, " item_index: ", length(recording.item_index)) + println(io, " item_correctness: ", length(recording.item_correctness)) + for (name, data) in pairs(recording.data) + println(io, " $name: ", size(data.data)) + end +end + +#= +function CatRecording( + xs, + points, + ability_ests, + num_questions, + num_respondents, + actual_abilities = nothing) + num_values = num_questions * num_respondents + if xs === nothing + xs_vec = nothing + else + xs_vec = collect(xs) + end + + CatRecorder(1, + 1, + points, + zeros(Int, num_values), + ability_ests, + zeros(Float64, num_values), + zeros(Int, num_values), + xs_vec, + zeros(points, num_values), + zeros(points, num_values), + zeros(points, num_values), + zeros(num_questions, num_respondents), + zeros(Int, num_questions, num_respondents), + zeros(Bool, num_questions, num_respondents), + Dict{Tuple{Int, Int}, Int}(), + actual_abilities) +end +=# + +function record!(recording::CatRecording, responses; data...) + #push_ability_est!(recording.ability_ests, recording.col_idx, ability_est) + + item_index = responses.indices[end] + item_correct = responses.values[end] > 0 + push!(recording.item_index, item_index) + push!(recording.item_correctness, item_correct) +end + +#= +""" +$(TYPEDSIGNATURES) +""" +function CatRecorder( + xs, + points, + ability_ests, + num_questions, + num_respondents, + integrator, + raw_estimator, + ability_estimator, + actual_abilities = nothing) + CatRecorder( + CatRecording( + xs, + points, + ability_ests, + num_questions, + num_respondents, + actual_abilities + ), + AbilityIntegrator(integrator), + raw_estimator, + ability_estimator, + ) +end + +function CatRecorder( + xs::AbstractVector{Float64}, + responses, + integrator, + raw_estimator, + ability_estimator, + actual_abilities = nothing + ) + points = size(xs, 1) + num_questions = size(responses, 1) + num_respondents = size(responses, 2) + num_values = num_questions * num_respondents + CatRecorder( + xs, + points, + zeros(num_values), + num_questions, + num_respondents, + integrator, + raw_estimator, + ability_estimator, + actual_abilities) +end + +function CatRecorder( + xs::AbstractMatrix{Float64}, + responses, + integrator, + raw_estimator, + ability_estimator, + actual_abilities = nothing + ) + dims = size(xs, 1) + points = size(xs, 2) + num_questions = size(responses, 1) + num_respondents = size(responses, 2) + num_values = num_questions * num_respondents + CatRecorder(xs, + points, + zeros(dims, num_values), + num_questions, + num_respondents, + integrator, + raw_estimator, + ability_estimator, + actual_abilities) +end + +function CatRecorder( + xs::AbstractVector{Float64}, + max_responses::Int, + integrator, + raw_estimator, + ability_estimator, + actual_abilities = nothing + ) + points = size(xs, 1) + CatRecorder(xs, + points, + zeros(max_responses), + max_responses, + 1, + integrator, + raw_estimator, + ability_estimator, + actual_abilities) +end + +function CatRecorder( + xs::AbstractMatrix{Float64}, + max_responses::Int, + integrator, + raw_estimator, + ability_estimator, + actual_abilities = nothing + ) + dims = size(xs, 1) + points = size(xs, 2) + CatRecorder(xs, + points, + zeros(dims, max_responses), + max_responses, + 1, + integrator, + raw_estimator, + ability_estimator, + actual_abilities) +end +=# + +function CatRecorder(dims::Int, expected_responses::Int; requests...) + out = [] + sizehint!(out, length(requests)) + for (name, request) in pairs(requests) + if request.type == :ability_value + data = empty_capacity(Float64, expected_responses) + elseif request.type == :ability_distribution + if dims == 0 + data = empty_capacity(Float64, length(request.points), expected_responses) + else + data = empty_capacity(Float64, dims, length(request.points), expected_responses) + end + end + push!(out, (name => (; + type=request.type, + data=data, + ))) + end + return CatRecorder(; + recording=CatRecording(NamedTuple(out), expected_responses), + requests=NamedTuple(requests), + ) + #= + CatRecording( + xs, + points, + ability_ests, + num_questions, + num_respondents, + actual_abilities + ), + AbilityIntegrator(integrator), + raw_estimator, + ability_estimator + =# +end + + +function push_ability_est!(ability_ests::AbstractMatrix{Float64}, col_idx, ability_est) + ability_ests[:, col_idx] = ability_est +end + +function push_ability_est!(ability_ests::AbstractVector{Float64}, col_idx, ability_est) + ability_ests[col_idx] = ability_est +end + +function eachmatcol(xs::AbstractMatrix) + eachcol(xs) +end + +function eachmatcol(xs::AbstractVector) + xs +end + +#= +function save_sampled(xs::Nothing, integrator::RiemannEnumerationIntegrator, + recorder::CatRecorder, tracked_responses, ir, item_correct) + # In this case, the item bank is probably sampled so we can use that + + # Save likelihoods + dist_est = distribution_estimator(recorder.ability_estimator) + denom = normdenom(integrator, dist_est, tracked_responses) + recorder.likelihoods[:, recorder.col_idx] = function_ys( + Aggregators.pdf( + dist_est, + tracked_responses + ) + ) ./ denom + raw_denom = normdenom(integrator, recorder.raw_estimator, tracked_responses) + recorder.raw_likelihoods[:, recorder.col_idx] = function_ys( + Aggregators.pdf( + recorder.raw_estimator, + tracked_responses + ) + ) ./ raw_denom + + # Save item responses + recorder.item_responses[:, recorder.col_idx] = item_ys(ir, item_correct) +end +=# + +function sample_likelihood(tracked_responses, xs, dist_est, integrator) + # Save likelihoods + num = Aggregators.pdf.( + dist_est, + tracked_responses, + eachmatcol(xs) + ) + denom = normdenom(integrator, dist_est, tracked_responses) + return num ./ denom +end + +#= + raw_denom = normdenom(integrator, recorder.raw_estimator, tracked_responses) + recorder.raw_likelihoods[:, recorder.col_idx] = Aggregators.pdf.( + Ref(recorder.raw_estimator), + Ref(tracked_responses), + eachmatcol(xs)) ./ raw_denom +=# + +function service_requests!( + #xs, integrator, recorder::CatRecorder, tracked_responses, ir, item_correct) + recorder::CatRecorder, tracked_responses, ir, item_correct +) + out = recorder.recording.data + for (name, request) in pairs(recorder.requests) + if request.type == :ability_value + push!(out[name].data, request.estimator(tracked_responses)) + elseif request.type == :ability_distribution + likelihood_sample = sample_likelihood(tracked_responses, request.points, request.estimator, request.integrator) + @info "pushing" name size(out[name].data) size(likelihood_sample) + elastic_push!(out[name].data, likelihood_sample) + end + end + + #= + # Save item responses + recorder.item_responses[:, recorder.col_idx] = resp.(Ref(ir), + item_correct, + eachmatcol(xs)) + =# +end + +""" +$(TYPEDSIGNATURES) +""" +function record!(recorder::CatRecorder, tracked_responses) + item_index = tracked_responses.responses.indices[end] + item_correct = tracked_responses.responses.values[end] > 0 + ir = ItemResponse(tracked_responses.item_bank, item_index) + service_requests!(recorder, tracked_responses, ir, item_correct) + record!(recorder.recording, tracked_responses.responses) +end + +function catrecorder_callback(recoder::CatRecorder) + return (tracked_responses, _) -> record!(recoder, tracked_responses) +end diff --git a/src/sim/run.jl b/src/sim/run.jl new file mode 100644 index 0000000..b6327f7 --- /dev/null +++ b/src/sim/run.jl @@ -0,0 +1,84 @@ +""" +$(TYPEDSIGNATURES) + +This response callback simply prompts the user for the response using the console +""" +function prompt_response(index_, label) + println("Response for $label > ") + parse(Int8, readline()) +end + +""" +$(TYPEDSIGNATURES) + +This function constructs a next item function which automatically responds +according to `responses`. +""" +function auto_responder(responses) + function (index, label_) + responses[index] + end +end + +abstract type NextItemError <: Exception end + +function item_label(ib_labels, next_index) + default_next_label(next_index) = "<>" + if ib_labels === nothing + return default_next_label(next_index) + else + return get(default_next_label, ib_labels, next_index) + end +end + +""" +```julia +$(FUNCTIONNAME)(cat_config::CatLoop, item_bank::AbstractItemBank; ib_labels=nothing) +``` + +Run a given [CatLoop](@ref) `cat_config` on the given `item_bank`. +If `ib_labels` is not given, default labels of the form +`<>` are passed to the callback. +""" +function run_cat(loop::CatLoop{RulesT}, + item_bank::AbstractItemBank; + ib_labels = nothing) where {RulesT <: CatRules} + (; rules, get_response, new_response_callback) = loop + (; next_item, termination_condition, ability_estimator, ability_tracker) = rules + responses = TrackedResponses(BareResponses(ResponseType(item_bank)), + item_bank, + ability_tracker) + while true + local next_index + @debug begin + criteria = compute_criteria(next_item, responses, item_bank) + "Best items" + end criteria + try + next_index = best_item(next_item, responses, item_bank) + catch exc + if isa(exc, NextItemError) + @warn "Terminating early due to error getting next item" err=sprint( + showerror, + exc) + break + else + rethrow() + end + end + next_label = item_label(ib_labels, next_index) + @debug "Querying" next_label + response = get_response(next_index, next_label) + @debug "Got response" response + add_response!(responses, Response(ResponseType(item_bank), next_index, response)) + terminating = termination_condition(responses, item_bank) + if new_response_callback !== nothing + new_response_callback(responses, terminating) + end + if terminating + @debug "Met termination condition" + break + end + end + (responses.responses, ability_estimator(responses)) +end \ No newline at end of file diff --git a/test/dt.jl b/test/dt.jl index 8c89b07..cf5daa1 100644 --- a/test/dt.jl +++ b/test/dt.jl @@ -21,7 +21,7 @@ get_response = auto_responder(@view true_responses[:, 1]) termination_condition = termination_condition, ability_estimator = ability_estimator ) - cat_loop_config = CatLoopConfig( + cat_loop_config = CatLoop( rules = cat_rules, get_response = get_response ) @@ -33,7 +33,7 @@ get_response = auto_responder(@view true_responses[:, 1]) ability_estimator = ability_estimator ) dt_materialized = generate_dt_cat(dt_generation_config, item_bank) - dt_loop_config = CatLoopConfig( + dt_loop_config = CatLoop( rules = dt_materialized, get_response = get_response ) @@ -45,7 +45,7 @@ get_response = auto_responder(@view true_responses[:, 1]) tempdir = mktempdir() save_mmap(tempdir, dt_materialized) dt_rt = load_mmap(tempdir) - dt_rt_loop_config = CatLoopConfig( + dt_rt_loop_config = CatLoop( rules = dt_rt, get_response = get_response ) diff --git a/test/runtests.jl b/test/runtests.jl index 75ccddc..beb5390 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,6 @@ using ComputerAdaptiveTesting.Aggregators using FittedItemBanks.DummyData: dummy_full, SimpleItemBankSpec, StdModel3PL, VectorContinuousDomain, BooleanResponse, std_normal using FittedItemBanks -using ComputerAdaptiveTesting.CatConfig using ComputerAdaptiveTesting.Responses using ComputerAdaptiveTesting.NextItemRules using ComputerAdaptiveTesting.TerminationConditions @@ -12,7 +11,7 @@ using ComputerAdaptiveTesting.Sim using PsychometricsBazaarBase.Integrators using PsychometricsBazaarBase.Optimizers using ComputerAdaptiveTesting.DecisionTree -using ComputerAdaptiveTesting: Stateful +using ComputerAdaptiveTesting: Stateful, CatRules using Distributions using Distributions: ZeroMeanIsoNormal, Zeros, ScalMat using Optim diff --git a/test/smoke.jl b/test/smoke.jl index 593d08a..d63402a 100644 --- a/test/smoke.jl +++ b/test/smoke.jl @@ -28,7 +28,7 @@ using .Dummy ) for testee_idx in axes(true_responses, 2) responses, ability = run_cat( - CatLoopConfig( + CatLoop( rules = rules, get_response = auto_responder(@view true_responses[:, testee_idx]) ), From 4f7d7db5861af8c48897c633a7a065a1f87c02df Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 21 Jun 2025 14:47:12 +0300 Subject: [PATCH 29/42] Rename PiecewiseNextItemRule => FixedRuleSequencer --- src/next_item_rules/NextItemRules.jl | 2 +- src/next_item_rules/strategies/pointwise.jl | 4 ++-- src/next_item_rules/strategies/sequential.jl | 13 +++++-------- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index efe0f78..ec04fb6 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -47,7 +47,7 @@ export RawEmpiricalInformationPointwiseItemCategoryCriterion export EmpiricalInformationPointwiseItemCategoryCriterion export TotalItemInformation export RandomNextItemRule -export PiecewiseNextItemRule, MemoryNextItemRule, FixedFirstItemNextItemRule +export FixedRuleSequencer, MemoryNextItemRule, FixedFirstItemNextItemRule export ExhaustiveSearch, RandomesqueStrategy export preallocate export compute_criteria, compute_criterion, compute_multi_criterion diff --git a/src/next_item_rules/strategies/pointwise.jl b/src/next_item_rules/strategies/pointwise.jl index 80c0709..763d40d 100644 --- a/src/next_item_rules/strategies/pointwise.jl +++ b/src/next_item_rules/strategies/pointwise.jl @@ -17,5 +17,5 @@ function best_item(rule::PointwiseNextItemRule, responses::TrackedResponses, ite end function PointwiseFirstNextItemRule(criterion, points, rule) - PiecewiseNextItemRule((length(points),), (PointwiseNextItemRule(criterion, points), rule)) -end \ No newline at end of file + FixedRuleSequencer((length(points),), (PointwiseNextItemRule(criterion, points), rule)) +end diff --git a/src/next_item_rules/strategies/sequential.jl b/src/next_item_rules/strategies/sequential.jl index 845e3ae..ff77b60 100644 --- a/src/next_item_rules/strategies/sequential.jl +++ b/src/next_item_rules/strategies/sequential.jl @@ -2,11 +2,8 @@ $(TYPEDEF) $(TYPEDFIELDS) -This is the most basic rule for choosing the next item in a CAT. It simply -picks a random item from the set of items that have not yet been -administered. """ -@kwdef struct PiecewiseNextItemRule{RulesT} <: NextItemRule +@kwdef struct FixedRuleSequencer{RulesT} <: NextItemRule # Tuple of Ints breaks::Tuple{Int} # Tuple of NextItemRules @@ -15,7 +12,7 @@ end #tuple_len(::NTuple{N, Any}) where {N} = Val{N}() -function current_rule(rule::PiecewiseNextItemRule, responses::TrackedResponses) +function current_rule(rule::FixedRuleSequencer, responses::TrackedResponses) for brk in 1:length(rule.breaks) if length(responses) < rule.breaks[brk] return rule.rules[brk] @@ -24,11 +21,11 @@ function current_rule(rule::PiecewiseNextItemRule, responses::TrackedResponses) return rule.rules[end] end -function best_item(rule::PiecewiseNextItemRule, responses::TrackedResponses, items) +function best_item(rule::FixedRuleSequencer, responses::TrackedResponses, items) return best_item(current_rule(rule, responses), responses, items) end -function compute_criteria(rule::PiecewiseNextItemRule, responses::TrackedResponses) +function compute_criteria(rule::FixedRuleSequencer, responses::TrackedResponses) return compute_criteria(current_rule(rule, responses), responses) end @@ -47,5 +44,5 @@ function best_item(rule::MemoryNextItemRule, responses::TrackedResponses, _items end function FixedFirstItemNextItemRule(item_idx::Int, rule::NextItemRule) - PiecewiseNextItemRule((1,), (MemoryNextItemRule((item_idx,)), rule)) + FixedRuleSequencer((1,), (MemoryNextItemRule((item_idx,)), rule)) end \ No newline at end of file From 78aab0798f6d7837cc611d94be1dd73df89101a5 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 21 Jun 2025 15:11:57 +0300 Subject: [PATCH 30/42] Add/improve show(...) methods --- src/aggregators/Aggregators.jl | 1 + src/aggregators/ability_estimator.jl | 9 +++++++ src/aggregators/optimizers.jl | 16 ++++++++++++ .../combinators/expectation.jl | 17 ++++++++++++ .../criteria/pointwise/information.jl | 26 +++++++++++++++++++ .../criteria/state/ability_variance.jl | 9 +++++++ src/next_item_rules/prelude/next_item_rule.jl | 12 +++++---- src/next_item_rules/strategies/balance.jl | 3 +-- src/next_item_rules/strategies/pointwise.jl | 10 +++++++ src/next_item_rules/strategies/randomesque.jl | 4 +++ src/next_item_rules/strategies/sequential.jl | 16 ++++++++++++ 11 files changed, 116 insertions(+), 7 deletions(-) diff --git a/src/aggregators/Aggregators.jl b/src/aggregators/Aggregators.jl index bd302df..f4b7281 100644 --- a/src/aggregators/Aggregators.jl +++ b/src/aggregators/Aggregators.jl @@ -33,6 +33,7 @@ using PsychometricsBazaarBase.Integrators: Integrators, using PsychometricsBazaarBase.Optimizers: OneDimOptimOptimizer, Optimizer using PsychometricsBazaarBase.ConstDistributions: std_normal, std_mv_normal import Distributions: pdf +import Base: show import FittedItemBanks import PsychometricsBazaarBase.IntegralCoeffs diff --git a/src/aggregators/ability_estimator.jl b/src/aggregators/ability_estimator.jl index da468cc..c81de70 100644 --- a/src/aggregators/ability_estimator.jl +++ b/src/aggregators/ability_estimator.jl @@ -212,6 +212,15 @@ function ModeAbilityEstimator(bits...) ModeAbilityEstimator(dist_est, optimizer) end +function show(io::IO, ::MIME"text/plain", ability_estimator::ModeAbilityEstimator) + println(io, "Estimate ability using its mode") + indent_io = indent(io, 2; skip_first=true) + print(indent_io, "Distribution estimator ") + show(indent_io, ability_estimator.dist_est) + print(indent_io, "Optimizer: ") + show(indent_io, ability_estimator.optim) +end + struct MeanAbilityEstimator{ DistEst <: DistributionAbilityEstimator, IntegratorT <: AbilityIntegrator diff --git a/src/aggregators/optimizers.jl b/src/aggregators/optimizers.jl index 6585c28..2f77127 100644 --- a/src/aggregators/optimizers.jl +++ b/src/aggregators/optimizers.jl @@ -10,6 +10,22 @@ function (optim::FunctionOptimizer)(f::F, optim.optim(comp_f) end +function show(io::IO, ::MIME"text/plain", optim::FunctionOptimizer) + indent_io = indent(io, 2) + if optim.optim isa OneDimOptimOptimizer || optim.optim isa MultiDimOptimOptimizer || optim.optim isa NativeOneDimOptimOptimizer + inner = optim.optim + println(io, "Optimizer:") + if optim.optim isa NativeOneDimOptimOptimizer + name = typeof(inner.method).name.name + else + name = typeof(inner.optim).name.name + end + print(indent_io, "Method: ", name) + print(indent_io, "Lo: ", inner.lo) + print(indent_io, "Hi: ", inner.hi) + end +end + #= """ Argmax + max over the ability likihood given a set of responses with a given diff --git a/src/next_item_rules/combinators/expectation.jl b/src/next_item_rules/combinators/expectation.jl index 87c945b..57d7b69 100644 --- a/src/next_item_rules/combinators/expectation.jl +++ b/src/next_item_rules/combinators/expectation.jl @@ -39,6 +39,14 @@ function Aggregators.response_expectation( item_idx) end +function show(io::IO, ::MIME"text/plain", point_response_expectation::PointResponseExpectation) + println(io, "Expected response at point ability estimate") + indent_io = indent(io, 2) + print(indent_io, "Ability estimator: ") + summary(indent_io, point_response_expectation.ability_estimator) + println(io) +end + struct DistributionResponseExpectation{ DistributionAbilityEstimatorT <: DistributionAbilityEstimator, AbilityIntegratorT <: AbilityIntegrator @@ -124,3 +132,12 @@ function compute_criterion( end res end + +function show(io::IO, ::MIME"text/plain", item_criterion::ExpectationBasedItemCriterion) + println(io, "Optimize an state/item/item-category criterion based on an expected response") + indent_io = indent(io, 2) + print(indent_io, "Expected response obtained by: ") + show(indent_io, MIME"text/plain"(), item_criterion.response_expectation) + print(indent_io, "Criterion: ") + show(indent_io, MIME"text/plain"(), item_criterion.criterion) +end \ No newline at end of file diff --git a/src/next_item_rules/criteria/pointwise/information.jl b/src/next_item_rules/criteria/pointwise/information.jl index a9845e6..1104bef 100644 --- a/src/next_item_rules/criteria/pointwise/information.jl +++ b/src/next_item_rules/criteria/pointwise/information.jl @@ -22,6 +22,10 @@ function compute_criterion_vec( -actual end +function show(io::IO, ::MIME"text/plain", ::ObservedInformationPointwiseItemCategoryCriterion) + println(io, "Observed pointwise item-category information") +end + """ See EmpiricalInformationPointwiseItemCategoryCriterion for more details. """ @@ -46,6 +50,11 @@ function compute_criterion_vec( -actual end + +function show(io::IO, ::MIME"text/plain", ::RawEmpiricalInformationPointwiseItemCategoryCriterion) + println(io, "Raw empirical pointwise item-category information") +end + """ In equation 10 of [1] we see that we can compute information using 2nd derivatives of log likelihood or 1st derivative squared. For single categories, we need to an extra term which disappears when we calculate the total see [2]. @@ -95,6 +104,10 @@ function compute_criterion_vec( -actual end +function show(io::IO, ::MIME"text/plain", ::EmpiricalInformationPointwiseItemCategoryCriterion) + println(io, "Empirical pointwise item-category information") +end + #= """ This implements Fisher information as a pointwise item criterion. @@ -116,4 +129,17 @@ function compute_criterion( ability ) sum(compute_criterion_vec(tii.pcic, ir, ability)) +end + +function show(io::IO, ::MIME"text/plain", rule::TotalItemInformation) + if rule.pcic isa ObservedInformationPointwiseItemCategoryCriterion + println(io, "Observed pointwise item information") + elseif rule.pcic isa RawEmpiricalInformationPointwiseItemCategoryCriterion + println(io, "Raw empirical pointwise item information") + elseif rule.pcic isa EmpiricalInformationPointwiseItemCategoryCriterion + println(io, "Empirical pointwise item information") + else + print(io, "Total ") + show(io, MIME("text/plain"), rule.pcic) + end end \ No newline at end of file diff --git a/src/next_item_rules/criteria/state/ability_variance.jl b/src/next_item_rules/criteria/state/ability_variance.jl index 6e44b49..19f2eab 100644 --- a/src/next_item_rules/criteria/state/ability_variance.jl +++ b/src/next_item_rules/criteria/state/ability_variance.jl @@ -68,6 +68,15 @@ function compute_criterion( denom) end +function show(io::IO, ::MIME"text/plain", criterion::AbilityVarianceStateCriterion) + println(io, "Minimise variance of ability estimate") + indent_io = indent(io, 2) + print(indent_io, "Distribution estimator: ") + show(indent_io, MIME"text/plain"(), criterion.dist_est) + print(indent_io, "Integrator: ") + show(indent_io, MIME"text/plain"(), criterion.integrator) +end + struct AbilityCovarianceStateMultiCriterion{ DistEstT <: DistributionAbilityEstimator, IntegratorT <: AbilityIntegrator diff --git a/src/next_item_rules/prelude/next_item_rule.jl b/src/next_item_rules/prelude/next_item_rule.jl index f0311f2..7b68b3f 100644 --- a/src/next_item_rules/prelude/next_item_rule.jl +++ b/src/next_item_rules/prelude/next_item_rule.jl @@ -53,11 +53,13 @@ function best_item(rule::NextItemRule, tracked_responses::TrackedResponses) best_item(rule, tracked_responses, tracked_responses.item_bank) end -function Base.show(io::IO, ::MIME"text/plain", next_item_rule::ItemStrategyNextItemRule) - println(io, "Strategy:") - show(indent_io, MIME"text/plain"(), rules.strategy) - println(io, "Item criterion:") - show(indent_io, MIME"text/plain"(), rules.criterion) +function Base.show(io::IO, ::MIME"text/plain", rule::ItemStrategyNextItemRule) + println(io, "Pick optimal item criterion according to strategy") + indent_io = indent(io, 2) + print(indent_io, "Strategy: ") + show(indent_io, MIME"text/plain"(), rule.strategy) + print(indent_io, "Item criterion: ") + show(indent_io, MIME"text/plain"(), rule.criterion) end # Default implementation diff --git a/src/next_item_rules/strategies/balance.jl b/src/next_item_rules/strategies/balance.jl index 2d3e6cd..5adcd20 100644 --- a/src/next_item_rules/strategies/balance.jl +++ b/src/next_item_rules/strategies/balance.jl @@ -41,9 +41,8 @@ end function show(io::IO, ::MIME"text/plain", rule::GreedyForcedContentBalancer) indent_io = indent(io, 2) - println(io, "Greedy + forced content balancer") + println(io, "Greedy + forced content balancing") println(indent_io, "Target ratio: " * join(rule.targets, ", ")) - print(indent_io, "Using rule: ") show(indent_io, MIME("text/plain"), rule.inner_rule) end diff --git a/src/next_item_rules/strategies/pointwise.jl b/src/next_item_rules/strategies/pointwise.jl index 763d40d..058e48e 100644 --- a/src/next_item_rules/strategies/pointwise.jl +++ b/src/next_item_rules/strategies/pointwise.jl @@ -16,6 +16,16 @@ function best_item(rule::PointwiseNextItemRule, responses::TrackedResponses, ite return idx end +function show(io::IO, ::MIME"text/plain", rule::PointwiseNextItemRule) + println(io, "Optimize a pointwise criterion at specified points") + indent_io = indent(io, 2) + points_desc = join(rule.points, ", ") + println(indent_io, "Points: $points_desc") + print(indent_io, "Criterion: ") + show(indent_io, MIME("text/plain"), rule.criterion) +end + + function PointwiseFirstNextItemRule(criterion, points, rule) FixedRuleSequencer((length(points),), (PointwiseNextItemRule(criterion, points), rule)) end diff --git a/src/next_item_rules/strategies/randomesque.jl b/src/next_item_rules/strategies/randomesque.jl index 2d0a6bf..4ae7e11 100644 --- a/src/next_item_rules/strategies/randomesque.jl +++ b/src/next_item_rules/strategies/randomesque.jl @@ -51,4 +51,8 @@ function best_item( items ) where {ItemCriterionT <: ItemCriterion} randomesque(rule.strategy.rng, rule.criterion, responses, items, rule.strategy.k)[1] +end + +function show(io::IO, ::MIME"text/plain", rule::RandomesqueStrategy) + println(io, "Randomesque strategy with k = $(rule.k)") end \ No newline at end of file diff --git a/src/next_item_rules/strategies/sequential.jl b/src/next_item_rules/strategies/sequential.jl index ff77b60..265cba9 100644 --- a/src/next_item_rules/strategies/sequential.jl +++ b/src/next_item_rules/strategies/sequential.jl @@ -29,6 +29,17 @@ function compute_criteria(rule::FixedRuleSequencer, responses::TrackedResponses) return compute_criteria(current_rule(rule, responses), responses) end +function show(io::IO, ::MIME"text/plain", rule::FixedRuleSequencer) + indent_io = indent(io, 2) + println(io, "Fixed rule sequencing:") + print(indent_io, "Firstly: ") + show(indent_io, MIME("text/plain"), rule.rules[1]) + for (responses, rule) in zip(rule.breaks, rule.rules[2:end]) + print(indent_io, "After $responses responses: ") + show(indent_io, MIME("text/plain"), rule) + end +end + """ """ @kwdef struct MemoryNextItemRule{MemoryT} <: NextItemRule @@ -43,6 +54,11 @@ function best_item(rule::MemoryNextItemRule, responses::TrackedResponses, _items # TODO: Add some basic error checking -- can only panic end +function show(io::IO, ::MIME"text/plain", rule::MemoryNextItemRule) + item_list = join(rule.item_idxs, ", ") + println(io, "Present the items indexed: $item_list") +end + function FixedFirstItemNextItemRule(item_idx::Int, rule::NextItemRule) FixedRuleSequencer((1,), (MemoryNextItemRule((item_idx,)), rule)) end \ No newline at end of file From 42f4f341f732299b3ced3fdc28d9e6b5d23e8e1f Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 21 Jun 2025 15:13:21 +0300 Subject: [PATCH 31/42] Remove trailing space --- src/aggregators/optimizers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aggregators/optimizers.jl b/src/aggregators/optimizers.jl index 2f77127..f7b2376 100644 --- a/src/aggregators/optimizers.jl +++ b/src/aggregators/optimizers.jl @@ -48,7 +48,7 @@ function (optim::EnumerationOptimizer)(f::F, ability_likelihood.item_bank; lo = lo, hi = hi) do (x, prob) - # @inline + # @inline fprob = f(x) * prob if fprob >= cur_max[] cur_argmax[] = x From ca8b99a7205c3d84761239e8e2987fca00d41a0a Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 21 Jun 2025 15:13:49 +0300 Subject: [PATCH 32/42] Add some docs to MemoryNextItemRule --- src/next_item_rules/strategies/sequential.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/next_item_rules/strategies/sequential.jl b/src/next_item_rules/strategies/sequential.jl index 265cba9..1c0c59c 100644 --- a/src/next_item_rules/strategies/sequential.jl +++ b/src/next_item_rules/strategies/sequential.jl @@ -41,6 +41,9 @@ function show(io::IO, ::MIME"text/plain", rule::FixedRuleSequencer) end """ +$(TYPEDEF) +$(TYPEDFIELDS) + """ @kwdef struct MemoryNextItemRule{MemoryT} <: NextItemRule item_idxs::MemoryT @@ -51,6 +54,7 @@ function best_item(rule::MemoryNextItemRule, responses::TrackedResponses, _items # XXX: A few problems with this: # 1. Could run out of `item_idxs` # 2. Could return an item not in `items` + # 3: Will not work if this is sequenced after items have already been administered # TODO: Add some basic error checking -- can only panic end From a5fb03dc87c1b215fbc890e6bffabdd67a8bda8b Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 21 Jun 2025 15:18:50 +0300 Subject: [PATCH 33/42] Rename PriorAbilityEstimator => PosteriorAbilityEstimator --- benchmark/benchmarks.jl | 2 +- docs/examples/examples/ability_convergence_3pl.jl | 4 ++-- docs/examples/examples/ability_convergence_mirt.jl | 4 ++-- docs/examples/examples/vocab_iq.jl | 6 +++--- profile/next_items.jl | 2 +- src/Compat/CatR.jl | 10 +++++----- src/Compat/MirtCAT.jl | 10 +++++----- src/aggregators/Aggregators.jl | 2 +- src/aggregators/ability_estimator.jl | 14 +++++++------- .../ability_trackers/closed_form_normal.jl | 2 +- src/sim/Sim.jl | 2 +- test/ability_estimator_1d.jl | 2 +- test/ability_estimator_2d.jl | 2 +- test/dummy.jl | 4 ++-- 14 files changed, 33 insertions(+), 33 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index ff1c61e..ec04dcf 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -30,7 +30,7 @@ function prepare_4pls(group) integrator = even_grid(-6.0, 6.0, mirtcat_quadpts(1)) optimizer = AbilityOptimizer(OneDimOptimOptimizer(-6.0, 6.0, NelderMead())) - dist_ability_estimator = PriorAbilityEstimator() + dist_ability_estimator = PosteriorAbilityEstimator() ability_estimators = [ ("mean", MeanAbilityEstimator(dist_ability_estimator, integrator)), ("mode", ModeAbilityEstimator(dist_ability_estimator, optimizer)) diff --git a/docs/examples/examples/ability_convergence_3pl.jl b/docs/examples/examples/ability_convergence_3pl.jl index eefc264..744fb8b 100644 --- a/docs/examples/examples/ability_convergence_3pl.jl +++ b/docs/examples/examples/ability_convergence_3pl.jl @@ -22,7 +22,7 @@ using ComputerAdaptiveTesting using ComputerAdaptiveTesting.Sim: auto_responder using ComputerAdaptiveTesting.NextItemRules: AbilityVarianceStateCriterion using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition -using ComputerAdaptiveTesting.Aggregators: PriorAbilityEstimator, +using ComputerAdaptiveTesting.Aggregators: PosteriorAbilityEstimator, MeanAbilityEstimator, LikelihoodAbilityEstimator using FittedItemBanks using ComputerAdaptiveTesting.Responses: BooleanResponse @@ -46,7 +46,7 @@ using FittedItemBanks.DummyData: dummy_full, std_normal, SimpleItemBankSpec, Std # CatRecorder collects information which can be used to draw different types of plots. max_questions = 99 integrator = FixedGKIntegrator(-6, 6, 80) -dist_ability_est = PriorAbilityEstimator(std_normal) +dist_ability_est = PosteriorAbilityEstimator(std_normal) ability_estimator = MeanAbilityEstimator(dist_ability_est, integrator) rules = CatRules(ability_estimator, AbilityVarianceStateCriterion(dist_ability_est, integrator), diff --git a/docs/examples/examples/ability_convergence_mirt.jl b/docs/examples/examples/ability_convergence_mirt.jl index e481b87..adb480f 100644 --- a/docs/examples/examples/ability_convergence_mirt.jl +++ b/docs/examples/examples/ability_convergence_mirt.jl @@ -22,7 +22,7 @@ using ComputerAdaptiveTesting using ComputerAdaptiveTesting.Sim: auto_responder using ComputerAdaptiveTesting.NextItemRules: DRuleItemCriterion using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition -using ComputerAdaptiveTesting.Aggregators: PriorAbilityEstimator, +using ComputerAdaptiveTesting.Aggregators: PosteriorAbilityEstimator, MeanAbilityEstimator, LikelihoodAbilityEstimator using FittedItemBanks import PsychometricsBazaarBase.IntegralCoeffs @@ -49,7 +49,7 @@ using ComputerAdaptiveTesting.Responses: BooleanResponse # CatRecorder collects information which can be used to draw different types of plots. max_questions = 9 integrator = CubaIntegrator([-6.0, -6.0], [6.0, 6.0], CubaVegas(); rtol = 1e-2) -ability_estimator = MeanAbilityEstimator(PriorAbilityEstimator(std_mv_normal(dims)), +ability_estimator = MeanAbilityEstimator(PosteriorAbilityEstimator(std_mv_normal(dims)), integrator) rules = CatRules(ability_estimator, DRuleItemCriterion(ability_estimator), diff --git a/docs/examples/examples/vocab_iq.jl b/docs/examples/examples/vocab_iq.jl index 4c67479..0434daf 100644 --- a/docs/examples/examples/vocab_iq.jl +++ b/docs/examples/examples/vocab_iq.jl @@ -5,9 +5,9 @@ # --- #md # Running a CAT based based on real response data -# +# # This example shows how to run a CAT end-to-end on real data. -# +# # First a 1-dimensional IRT model is fitted based on open response data to the # vocabulary IQ test using the IRTSupport package which internally, this uses # the `mirt` R package. Next, the model is used to administer the test @@ -37,7 +37,7 @@ function run_vocab_iq_cat() item_bank, labels = get_item_bank() integrator = FixedGKIntegrator(-6, 6, 61) ability_integrator = AbilityIntegrator(integrator) - dist_ability_est = PriorAbilityEstimator(std_normal) + dist_ability_est = PosteriorAbilityEstimator(std_normal) optimizer = AbilityOptimizer(OneDimOptimOptimizer(-6.0, 6.0, NelderMead())) ability_estimator = ModeAbilityEstimator(dist_ability_est, optimizer) @info "run_cat" ability_estimator diff --git a/profile/next_items.jl b/profile/next_items.jl index e23ad07..ef8b5fc 100644 --- a/profile/next_items.jl +++ b/profile/next_items.jl @@ -20,7 +20,7 @@ function get_ability_estimator(multidim) integrator = FixedGKIntegrator(-6.0, 6.0) dist = Normal() end - return PriorAbilityEstimator(dist, integrator) + return PosteriorAbilityEstimator(dist, integrator) end function prepare_empty(item_bank, actual_responses, ability_tracker) diff --git a/src/Compat/CatR.jl b/src/Compat/CatR.jl index 6af9eb9..9da15c3 100644 --- a/src/Compat/CatR.jl +++ b/src/Compat/CatR.jl @@ -5,7 +5,7 @@ using ComputerAdaptiveTesting.Aggregators: AbilityIntegrator, DistributionAbilityEstimator, ModeAbilityEstimator, MeanAbilityEstimator, - PriorAbilityEstimator + PosteriorAbilityEstimator using ComputerAdaptiveTesting.TerminationConditions: RunForeverTerminationCondition using ComputerAdaptiveTesting.Rules: CatRules using ComputerAdaptiveTesting.NextItemRules @@ -51,9 +51,9 @@ const next_item_aliases = _next_item_aliases() function _ability_estimator_aliases() res = Dict{String, Any}() - res["BM"] = (; optimizer, kwargs...) -> ModeAbilityEstimator(PriorAbilityEstimator(), optimizer) + res["BM"] = (; optimizer, kwargs...) -> ModeAbilityEstimator(PosteriorAbilityEstimator(), optimizer) res["ML"] = (; optimizer, kwargs...) -> ModeAbilityEstimator(LikelihoodAbilityEstimator(), optimizer) - res["EAP"] = (; integrator, kwargs...) -> MeanAbilityEstimator(PriorAbilityEstimator(), integrator) + res["EAP"] = (; integrator, kwargs...) -> MeanAbilityEstimator(PosteriorAbilityEstimator(), integrator) #res["WL"] #res["ROB"] return res @@ -97,7 +97,7 @@ function assemble_rules(; integrator = setup_integrator() optimizer = setup_optimizer() ability_estimator = ability_estimator_aliases[method](; integrator, optimizer) - posterior_ability_estimator = PriorAbilityEstimator() + posterior_ability_estimator = PosteriorAbilityEstimator() raw_next_item = next_item_aliases[criterion](ability_estimator, integrator, optimizer; posterior_ability_estimator=posterior_ability_estimator) next_item = FixedFirstItemNextItemRule(start_item, raw_next_item) CatRules(; @@ -108,4 +108,4 @@ function assemble_rules(; ) end -end \ No newline at end of file +end diff --git a/src/Compat/MirtCAT.jl b/src/Compat/MirtCAT.jl index 97706e7..8a3d7a2 100644 --- a/src/Compat/MirtCAT.jl +++ b/src/Compat/MirtCAT.jl @@ -5,7 +5,7 @@ using ComputerAdaptiveTesting.Aggregators: SafeLikelihoodAbilityEstimator, DistributionAbilityEstimator, ModeAbilityEstimator, MeanAbilityEstimator, - PriorAbilityEstimator, + PosteriorAbilityEstimator, AbilityEstimator, distribution_estimator using ComputerAdaptiveTesting.TerminationConditions: RunForeverTerminationCondition @@ -71,9 +71,9 @@ to randomly select items, and 'seq' for selecting items sequentially =# const ability_estimator_aliases = Dict( - "MAP" => (; optimizer, ncomp, kwargs...) -> ModeAbilityEstimator(PriorAbilityEstimator(; ncomp=ncomp), optimizer), + "MAP" => (; optimizer, ncomp, kwargs...) -> ModeAbilityEstimator(PosteriorAbilityEstimator(; ncomp=ncomp), optimizer), "ML" => (; optimizer, ncomp, kwargs...) -> ModeAbilityEstimator(SafeLikelihoodAbilityEstimator(; ncomp=ncomp), optimizer), - "EAP" => (; integrator, ncomp, kwargs...) -> MeanAbilityEstimator(PriorAbilityEstimator(; ncomp=ncomp), integrator), + "EAP" => (; integrator, ncomp, kwargs...) -> MeanAbilityEstimator(PosteriorAbilityEstimator(; ncomp=ncomp), integrator), # "WLE" for weighted likelihood estimation # "EAPsum" for the expected a-posteriori for each sum score ) @@ -143,7 +143,7 @@ function assemble_rules(; integrator = setup_integrator(lo, hi, pts) optimizer = setup_optimizer(-theta_lim, theta_lim) ability_estimator = ability_estimator_aliases[method](; integrator, optimizer, ncomp) - posterior_ability_estimator = PriorAbilityEstimator(; ncomp) + posterior_ability_estimator = PosteriorAbilityEstimator(; ncomp) raw_next_item = next_item_aliases[criteria](ability_estimator, posterior_ability_estimator, integrator, optimizer) next_item = FixedFirstItemNextItemRule(start_item, raw_next_item) CatRules(; @@ -153,4 +153,4 @@ function assemble_rules(; ) end -end \ No newline at end of file +end diff --git a/src/aggregators/Aggregators.jl b/src/aggregators/Aggregators.jl index f4b7281..296eb26 100644 --- a/src/aggregators/Aggregators.jl +++ b/src/aggregators/Aggregators.jl @@ -42,7 +42,7 @@ export AbilityEstimator, TrackedResponses export AbilityTracker, NullAbilityTracker, PointAbilityTracker, GriddedAbilityTracker export ClosedFormNormalAbilityTracker, track! export response_expectation, expectation, distribution_estimator -export PointAbilityEstimator, PriorAbilityEstimator +export PointAbilityEstimator, PosteriorAbilityEstimator export SafeLikelihoodAbilityEstimator, LikelihoodAbilityEstimator export ModeAbilityEstimator, MeanAbilityEstimator export Speculator, replace_speculation!, normdenom, maybe_tracked_ability_estimate diff --git a/src/aggregators/ability_estimator.jl b/src/aggregators/ability_estimator.jl index c81de70..4739a74 100644 --- a/src/aggregators/ability_estimator.jl +++ b/src/aggregators/ability_estimator.jl @@ -26,19 +26,19 @@ function pdf(::LikelihoodAbilityEstimator, AbilityLikelihood(tracked_responses) end -struct PriorAbilityEstimator{PriorT <: Distribution} <: DistributionAbilityEstimator +struct PosteriorAbilityEstimator{PriorT <: Distribution} <: DistributionAbilityEstimator prior::PriorT end -function PriorAbilityEstimator(; ncomp = 0) +function PosteriorAbilityEstimator(; ncomp = 0) if ncomp == 0 - return PriorAbilityEstimator(std_normal) + return PosteriorAbilityEstimator(std_normal) else - return PriorAbilityEstimator(std_mv_normal(ncomp)) + return PosteriorAbilityEstimator(std_mv_normal(ncomp)) end end -function pdf(est::PriorAbilityEstimator, +function pdf(est::PosteriorAbilityEstimator, tracked_responses::TrackedResponses) IntegralCoeffs.PriorApply(IntegralCoeffs.Prior(est.prior), AbilityLikelihood(tracked_responses)) @@ -75,7 +75,7 @@ end function SafeLikelihoodAbilityEstimator(args...; kwargs...) GuardedAbilityEstimator( LikelihoodAbilityEstimator(), - PriorAbilityEstimator(args...), + PosteriorAbilityEstimator(args...), multiple_response_types_guard ) end @@ -289,7 +289,7 @@ function (est::MeanAbilityEstimator{AbilityEstimatorT, RiemannEnumerationIntegra tracked_responses) end -function maybe_apply_prior(f::F, est::PriorAbilityEstimator) where {F} +function maybe_apply_prior(f::F, est::PosteriorAbilityEstimator) where {F} IntegralCoeffs.PriorApply(IntegralCoeffs.Prior(est.prior), f) end diff --git a/src/aggregators/ability_trackers/closed_form_normal.jl b/src/aggregators/ability_trackers/closed_form_normal.jl index 0c81ffc..83505ea 100644 --- a/src/aggregators/ability_trackers/closed_form_normal.jl +++ b/src/aggregators/ability_trackers/closed_form_normal.jl @@ -2,7 +2,7 @@ mutable struct ClosedFormNormalAbilityTracker <: AbilityTracker cur_ability::VarNormal end -function ClosedFormNormalAbilityTracker(prior_ability_estimator::PriorAbilityEstimator) +function ClosedFormNormalAbilityTracker(prior_ability_estimator::PosteriorAbilityEstimator) @warn "ClosedFormNormalAbilityTracker is based on equations from Liden 1998 / Owen 1975, but these appear to give poor results" prior = prior_ability_estimator.prior if !(prior isa Normal) diff --git a/src/sim/Sim.jl b/src/sim/Sim.jl index 5cfe830..3ad7f43 100644 --- a/src/sim/Sim.jl +++ b/src/sim/Sim.jl @@ -15,7 +15,7 @@ using ..Aggregators: TrackedResponses, AbilityIntegrator, AbilityEstimator, LikelihoodAbilityEstimator, - PriorAbilityEstimator, + PosteriorAbilityEstimator, ModeAbilityEstimator, MeanAbilityEstimator, LikelihoodAbilityEstimator, diff --git a/test/ability_estimator_1d.jl b/test/ability_estimator_1d.jl index 7c80037..0869c9c 100644 --- a/test/ability_estimator_1d.jl +++ b/test/ability_estimator_1d.jl @@ -32,7 +32,7 @@ tracked_responses_1d = TrackedResponses(responses_1d, item_bank_1d, NullAbilityT integrator_1d = AbilityIntegrator(FixedGKIntegrator(-6.0, 6.0, 61)) optimizer_1d = AbilityOptimizer(OneDimOptimOptimizer(-6.0, 6.0, NelderMead())) lh_est_1d = LikelihoodAbilityEstimator() -pa_est_1d = PriorAbilityEstimator(Normal(1.0, 0.2)) +pa_est_1d = PosteriorAbilityEstimator(Normal(1.0, 0.2)) eap_1d = MeanAbilityEstimator(pa_est_1d, integrator_1d) map_1d = ModeAbilityEstimator(pa_est_1d, optimizer_1d) mle_mean_1d = MeanAbilityEstimator(lh_est_1d, integrator_1d) diff --git a/test/ability_estimator_2d.jl b/test/ability_estimator_2d.jl index c87cc8b..7e820d7 100644 --- a/test/ability_estimator_2d.jl +++ b/test/ability_estimator_2d.jl @@ -34,7 +34,7 @@ integrator_2d = AbilityIntegrator(MultiDimFixedGKIntegrator([-6.0, -6.0], [6.0, optimizer_2d = AbilityOptimizer(MultiDimOptimOptimizer( [-6.0, -6.0], [6.0, 6.0], NelderMead())) lh_est_2d = LikelihoodAbilityEstimator() -pa_est_2d = PriorAbilityEstimator(MvNormal([1.0, 1.0], ScalMat(2, 0.2))) +pa_est_2d = PosteriorAbilityEstimator(MvNormal([1.0, 1.0], ScalMat(2, 0.2))) eap_2d = MeanAbilityEstimator(pa_est_2d, integrator_2d) map_2d = ModeAbilityEstimator(pa_est_2d, optimizer_2d) mle_mean_2d = MeanAbilityEstimator(lh_est_2d, integrator_2d) diff --git a/test/dummy.jl b/test/dummy.jl index bb5b56e..20e644e 100644 --- a/test/dummy.jl +++ b/test/dummy.jl @@ -28,9 +28,9 @@ const integrators_1d = [ ] const ability_estimators_1d = [ ((:integrator,), - (stuff) -> MeanAbilityEstimator(PriorAbilityEstimator(std_normal), stuff.integrator)), + (stuff) -> MeanAbilityEstimator(PosteriorAbilityEstimator(std_normal), stuff.integrator)), ((:optimizer,), - (stuff) -> ModeAbilityEstimator(PriorAbilityEstimator(std_normal), stuff.optimizer)), + (stuff) -> ModeAbilityEstimator(PosteriorAbilityEstimator(std_normal), stuff.optimizer)), ((:integrator,), (stuff) -> MeanAbilityEstimator(LikelihoodAbilityEstimator(), stuff.integrator)), ((:optimizer,), From 7158dd64d914c41f54acdd2235e2a618aaacd0e7 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Wed, 2 Jul 2025 00:52:52 +0300 Subject: [PATCH 34/42] Fix up import in Compat.MirtCAT --- src/Compat/MirtCAT.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Compat/MirtCAT.jl b/src/Compat/MirtCAT.jl index 8a3d7a2..ae2865b 100644 --- a/src/Compat/MirtCAT.jl +++ b/src/Compat/MirtCAT.jl @@ -10,7 +10,7 @@ using ComputerAdaptiveTesting.Aggregators: SafeLikelihoodAbilityEstimator, distribution_estimator using ComputerAdaptiveTesting.TerminationConditions: RunForeverTerminationCondition using ComputerAdaptiveTesting.NextItemRules -using ComputerAdaptiveTesting: CatRules +using ComputerAdaptiveTesting.Rules: CatRules using PsychometricsBazaarBase: Integrators, Optimizers public next_item_aliases, ability_estimator_aliases, assemble_rules From be8bcbfefe92bff0b6d664abbb0eafcd65b49f62 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Wed, 2 Jul 2025 00:55:00 +0300 Subject: [PATCH 35/42] Improve show methods --- src/Rules.jl | 15 +++++----- src/TerminationConditions.jl | 5 ++++ src/aggregators/Aggregators.jl | 7 ++++- src/aggregators/ability_estimator.jl | 28 +++++++++++++++---- src/aggregators/optimizers.jl | 10 +++---- .../combinators/expectation.jl | 8 ++---- .../criteria/state/ability_variance.jl | 6 ++-- src/next_item_rules/prelude/next_item_rule.jl | 4 +-- src/next_item_rules/strategies/pointwise.jl | 1 - 9 files changed, 52 insertions(+), 32 deletions(-) diff --git a/src/Rules.jl b/src/Rules.jl index 40e7f34..3ec8c02 100644 --- a/src/Rules.jl +++ b/src/Rules.jl @@ -82,13 +82,14 @@ function CatRules(bits...) end function show(io::IO, ::MIME"text/plain", rules::CatRules) - indent_io = indent(io, 2) - println(io, "Next item rule:") - show(indent_io, MIME"text/plain"(), rules.next_item) - println(io, "Termination condition:") - show(indent_io, MIME"text/plain"(), rules.termination_condition) - println(io, "Ability estimator:") - show(indent_io, MIME"text/plain"(), rules.ability_estimator) + print(io, "Next item rule: ") + show(io, MIME("text/plain"), rules.next_item) + println(io) + print(io, "Termination condition: ") + show(io, MIME("text/plain"), rules.termination_condition) + println(io) + print(io, "Ability estimator: ") + show(io, MIME("text/plain"), rules.ability_estimator) end function _find_ability_estimator_and_tracker(bits...) diff --git a/src/TerminationConditions.jl b/src/TerminationConditions.jl index 5c99c69..9e4fa63 100644 --- a/src/TerminationConditions.jl +++ b/src/TerminationConditions.jl @@ -6,6 +6,7 @@ using ..Aggregators: TrackedResponses using ..ConfigBase using PsychometricsBazaarBase.ConfigTools: @returnsome, find1_instance using FittedItemBanks +import Base: show export TerminationCondition, LengthTerminationCondition, SimpleFunctionTerminationCondition @@ -32,6 +33,10 @@ function (condition::LengthTerminationCondition)(responses::TrackedResponses, length(responses) >= condition.num_items end +function show(io::IO, ::MIME"text/plain", condition::LengthTerminationCondition) + println(io, "Terminate test after administering $(condition.num_items) items") +end + # Alias for old name const FixedItemsTerminationCondition = LengthTerminationCondition diff --git a/src/aggregators/Aggregators.jl b/src/aggregators/Aggregators.jl index 296eb26..4e5b967 100644 --- a/src/aggregators/Aggregators.jl +++ b/src/aggregators/Aggregators.jl @@ -30,8 +30,9 @@ using PsychometricsBazaarBase.Integrators: Integrators, IntValue, Integrator, PreallocatedFixedGridIntegrator, normdenom -using PsychometricsBazaarBase.Optimizers: OneDimOptimOptimizer, Optimizer +using PsychometricsBazaarBase.Optimizers: OneDimOptimOptimizer, Optimizer, Optimizers using PsychometricsBazaarBase.ConstDistributions: std_normal, std_mv_normal +using PsychometricsBazaarBase.IndentWrappers: indent import Distributions: pdf import Base: show @@ -209,6 +210,10 @@ function (integrator::FunctionIntegrator{IntegratorT})(f::F, integrator.integrator(FunctionProduct(f, lh_function), ncomp) end +function show(io::IO, ::MIME"text/plain", responses::FunctionIntegrator) + show(io, MIME("text/plain"), responses.integrator) +end + # Defaults const optim_tol = 1e-12 const int_tol = 1e-8 diff --git a/src/aggregators/ability_estimator.jl b/src/aggregators/ability_estimator.jl index 4739a74..ed2d204 100644 --- a/src/aggregators/ability_estimator.jl +++ b/src/aggregators/ability_estimator.jl @@ -26,6 +26,10 @@ function pdf(::LikelihoodAbilityEstimator, AbilityLikelihood(tracked_responses) end +function show(io::IO, ::MIME"text/plain", ability_estimator::LikelihoodAbilityEstimator) + println(io, "Ability likelihood distribution") +end + struct PosteriorAbilityEstimator{PriorT <: Distribution} <: DistributionAbilityEstimator prior::PriorT end @@ -57,6 +61,14 @@ function multiple_response_types_guard(tracked_responses) return false end +function show(io::IO, ::MIME"text/plain", ability_estimator::PosteriorAbilityEstimator) + println(io, "Ability posterior distribution") + indent_io = indent(io, 2) + print(indent_io, "Prior: ") + show(indent_io, MIME("text/plain"), ability_estimator.prior) + println(io) +end + struct GuardedAbilityEstimator{T <: DistributionAbilityEstimator, U <: DistributionAbilityEstimator, F} <: DistributionAbilityEstimator est::T fallback::U @@ -214,11 +226,9 @@ end function show(io::IO, ::MIME"text/plain", ability_estimator::ModeAbilityEstimator) println(io, "Estimate ability using its mode") - indent_io = indent(io, 2; skip_first=true) - print(indent_io, "Distribution estimator ") - show(indent_io, ability_estimator.dist_est) - print(indent_io, "Optimizer: ") - show(indent_io, ability_estimator.optim) + indent_io = indent(io, 2) + show(indent_io, MIME("text/plain"), ability_estimator.dist_est) + show(indent_io, MIME("text/plain"), ability_estimator.optim) end struct MeanAbilityEstimator{ @@ -236,6 +246,14 @@ function MeanAbilityEstimator(bits...) MeanAbilityEstimator(dist_est, integrator) end +function show(io::IO, ::MIME"text/plain", ability_estimator::MeanAbilityEstimator) + println(io, "Estimate ability using its mean") + indent_io = indent(io, 2) + show(indent_io, MIME("text/plain"), ability_estimator.dist_est) + print(indent_io, "Integrator: ") + show(indent_io, MIME("text/plain"), ability_estimator.integrator) +end + function distribution_estimator(dist_est::DistributionAbilityEstimator)::DistributionAbilityEstimator dist_est end diff --git a/src/aggregators/optimizers.jl b/src/aggregators/optimizers.jl index f7b2376..314b6c3 100644 --- a/src/aggregators/optimizers.jl +++ b/src/aggregators/optimizers.jl @@ -12,17 +12,17 @@ end function show(io::IO, ::MIME"text/plain", optim::FunctionOptimizer) indent_io = indent(io, 2) - if optim.optim isa OneDimOptimOptimizer || optim.optim isa MultiDimOptimOptimizer || optim.optim isa NativeOneDimOptimOptimizer + if optim.optim isa Optimizers.OneDimOptimOptimizer || optim.optim isa Optimizers.MultiDimOptimOptimizer || optim.optim isa Optimizers.NativeOneDimOptimOptimizer inner = optim.optim println(io, "Optimizer:") - if optim.optim isa NativeOneDimOptimOptimizer + if optim.optim isa Optimizers.NativeOneDimOptimOptimizer name = typeof(inner.method).name.name else name = typeof(inner.optim).name.name end - print(indent_io, "Method: ", name) - print(indent_io, "Lo: ", inner.lo) - print(indent_io, "Hi: ", inner.hi) + println(indent_io, "Method: ", name) + println(indent_io, "Lo: ", inner.lo) + println(indent_io, "Hi: ", inner.hi) end end diff --git a/src/next_item_rules/combinators/expectation.jl b/src/next_item_rules/combinators/expectation.jl index 57d7b69..7e51841 100644 --- a/src/next_item_rules/combinators/expectation.jl +++ b/src/next_item_rules/combinators/expectation.jl @@ -42,9 +42,7 @@ end function show(io::IO, ::MIME"text/plain", point_response_expectation::PointResponseExpectation) println(io, "Expected response at point ability estimate") indent_io = indent(io, 2) - print(indent_io, "Ability estimator: ") - summary(indent_io, point_response_expectation.ability_estimator) - println(io) + show(indent_io, MIME("text/plain"), point_response_expectation.ability_estimator) end struct DistributionResponseExpectation{ @@ -136,8 +134,6 @@ end function show(io::IO, ::MIME"text/plain", item_criterion::ExpectationBasedItemCriterion) println(io, "Optimize an state/item/item-category criterion based on an expected response") indent_io = indent(io, 2) - print(indent_io, "Expected response obtained by: ") show(indent_io, MIME"text/plain"(), item_criterion.response_expectation) - print(indent_io, "Criterion: ") show(indent_io, MIME"text/plain"(), item_criterion.criterion) -end \ No newline at end of file +end diff --git a/src/next_item_rules/criteria/state/ability_variance.jl b/src/next_item_rules/criteria/state/ability_variance.jl index 19f2eab..3ee74db 100644 --- a/src/next_item_rules/criteria/state/ability_variance.jl +++ b/src/next_item_rules/criteria/state/ability_variance.jl @@ -71,10 +71,8 @@ end function show(io::IO, ::MIME"text/plain", criterion::AbilityVarianceStateCriterion) println(io, "Minimise variance of ability estimate") indent_io = indent(io, 2) - print(indent_io, "Distribution estimator: ") - show(indent_io, MIME"text/plain"(), criterion.dist_est) - print(indent_io, "Integrator: ") - show(indent_io, MIME"text/plain"(), criterion.integrator) + show(indent_io, MIME("text/plain"), criterion.dist_est) + show(indent_io, MIME("text/plain"), criterion.integrator) end struct AbilityCovarianceStateMultiCriterion{ diff --git a/src/next_item_rules/prelude/next_item_rule.jl b/src/next_item_rules/prelude/next_item_rule.jl index 7b68b3f..7da78ae 100644 --- a/src/next_item_rules/prelude/next_item_rule.jl +++ b/src/next_item_rules/prelude/next_item_rule.jl @@ -56,13 +56,11 @@ end function Base.show(io::IO, ::MIME"text/plain", rule::ItemStrategyNextItemRule) println(io, "Pick optimal item criterion according to strategy") indent_io = indent(io, 2) - print(indent_io, "Strategy: ") show(indent_io, MIME"text/plain"(), rule.strategy) - print(indent_io, "Item criterion: ") show(indent_io, MIME"text/plain"(), rule.criterion) end # Default implementation function compute_criteria(::NextItemRule, ::TrackedResponses) nothing -end \ No newline at end of file +end diff --git a/src/next_item_rules/strategies/pointwise.jl b/src/next_item_rules/strategies/pointwise.jl index 058e48e..e0a5616 100644 --- a/src/next_item_rules/strategies/pointwise.jl +++ b/src/next_item_rules/strategies/pointwise.jl @@ -21,7 +21,6 @@ function show(io::IO, ::MIME"text/plain", rule::PointwiseNextItemRule) indent_io = indent(io, 2) points_desc = join(rule.points, ", ") println(indent_io, "Points: $points_desc") - print(indent_io, "Criterion: ") show(indent_io, MIME("text/plain"), rule.criterion) end From c21b14c218e6576ead9e92be73c312a66f8f1611 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Wed, 2 Jul 2025 00:56:37 +0300 Subject: [PATCH 36/42] Improve CatRecorder --- src/sim/Sim.jl | 2 ++ src/sim/recorder.jl | 56 +++++++++++++++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/src/sim/Sim.jl b/src/sim/Sim.jl index 3ad7f43..75511c9 100644 --- a/src/sim/Sim.jl +++ b/src/sim/Sim.jl @@ -1,11 +1,13 @@ module Sim +using DataFrames: DataFrame using ElasticArrays using ElasticArrays: sizehint_lastdim! using DocStringExtensions using StatsBase using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse using PsychometricsBazaarBase.Integrators +using PsychometricsBazaarBase.IndentWrappers: indent using ..ConfigBase using ..Responses using ..Rules: CatRules diff --git a/src/sim/recorder.jl b/src/sim/recorder.jl index 2b017e9..b897a88 100644 --- a/src/sim/recorder.jl +++ b/src/sim/recorder.jl @@ -20,12 +20,19 @@ Base.@kwdef mutable struct CatRecording{LikelihoodsT <: NamedTuple} #likelihoods::Matrix{Float64} #raw_likelihoods::Matrix{Float64} data::LikelihoodsT - item_responses::Vector{Float64} item_index::Vector{Int} item_correctness::Vector{Bool} rules_description::Union{Nothing, String} = nothing end +function Base.getproperty(obj::CatRecording, sym::Symbol) + if hasfield(CatRecording, sym) + return getfield(obj, sym) + else + return getproperty(obj.data, sym) + end +end + Base.@kwdef struct CatRecorder{RequestsT <: NamedTuple, LikelihoodsT <: NamedTuple} recording::CatRecording{LikelihoodsT} requests::RequestsT @@ -40,28 +47,46 @@ function CatRecording( ) CatRecording(; data=data, - item_responses=empty_capacity(Float64, expected_responses), item_index=empty_capacity(Int, expected_responses), item_correctness=empty_capacity(Bool, expected_responses) ) end +function prepare_dataframe(recording::CatRecording) + cols::Vector{Pair{String, Vector{Any}}} = [ + "Item" => recording.item_index, + "Response" => recording.item_correctness, + ] + for (name, value) in pairs(recording.data) + #@show name value.type keys(value) size(value.data) + if value.data isa AbstractVector + push!(cols, String(name) => value.data) + end + end + return DataFrame(cols) +end + function show(io::IO, ::MIME"text/plain", recording::CatRecording) println(io, "Recording of a Computer-Adaptive Test") if recording.rules_description === nothing println(io, " Unknown CAT configuration") else println(io, " CAT configuration:") - for line in split(recording.rules_description, "\n") + for line in split(strip(recording.rules_description, '\n'), "\n") println(io, " ", line) end end - println(io, " item_responses: ", length(recording.item_responses)) - println(io, " item_index: ", length(recording.item_index)) - println(io, " item_correctness: ", length(recording.item_correctness)) - for (name, data) in pairs(recording.data) - println(io, " $name: ", size(data.data)) + println(io) + println(io, " Recorded information:") + df = prepare_dataframe(recording) + buf = IOBuffer() + show(buf, MIME("text/plain"), df; summary=false, eltypes=false, rowlabel=:Number) + seekstart(buf) + for line in eachline(buf) + println(io, " ", line) end + #println(io) + #println(io, " Final information:") end #= @@ -226,11 +251,16 @@ function CatRecorder( end =# +function name_to_label(name) + titlecase(join(split(String(name), "_"), " ")) +end + function CatRecorder(dims::Int, expected_responses::Int; requests...) out = [] sizehint!(out, length(requests)) for (name, request) in pairs(requests) - if request.type == :ability_value + extra = (;) + if request.type in (:ability, :ability_stddev) data = empty_capacity(Float64, expected_responses) elseif request.type == :ability_distribution if dims == 0 @@ -238,10 +268,13 @@ function CatRecorder(dims::Int, expected_responses::Int; requests...) else data = empty_capacity(Float64, dims, length(request.points), expected_responses) end + extra = (; points = request.points) end push!(out, (name => (; + label=haskey(request, :label) ? request.label : name_to_label(name), type=request.type, - data=data, + data, + extra... ))) end return CatRecorder(; @@ -332,11 +365,10 @@ function service_requests!( ) out = recorder.recording.data for (name, request) in pairs(recorder.requests) - if request.type == :ability_value + if request.type in (:ability, :ability_stddev) push!(out[name].data, request.estimator(tracked_responses)) elseif request.type == :ability_distribution likelihood_sample = sample_likelihood(tracked_responses, request.points, request.estimator, request.integrator) - @info "pushing" name size(out[name].data) size(likelihood_sample) elastic_push!(out[name].data, likelihood_sample) end end From 7a625097b971caf37c37a24f8dbde2a5f79cd2e0 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 5 Jul 2025 00:54:18 +0300 Subject: [PATCH 37/42] More succinct naming --- benchmark/benchmarks.jl | 4 +-- .../examples/ability_convergence_3pl.jl | 8 ++--- .../examples/ability_convergence_mirt.jl | 4 +-- docs/examples/examples/vocab_iq.jl | 4 +-- docs/src/creating_a_cat.md | 10 +++--- docs/src/stateful.md | 2 +- src/Compat/CatR.jl | 16 ++++----- src/Compat/MirtCAT.jl | 10 +++--- src/Stateful.jl | 34 ++++++++++--------- src/TerminationConditions.jl | 17 ++++------ src/next_item_rules/NextItemRules.jl | 6 ++-- .../criteria/state/ability_variance.jl | 16 ++++----- src/next_item_rules/prelude/abstract.jl | 8 ++--- src/next_item_rules/prelude/criteria.jl | 4 +-- src/next_item_rules/prelude/next_item_rule.jl | 12 +++---- src/next_item_rules/strategies/exhaustive.jl | 2 +- src/next_item_rules/strategies/randomesque.jl | 2 +- src/next_item_rules/strategies/sequential.jl | 2 +- src/precompiles.jl | 12 +++---- test/ability_estimator_1d.jl | 2 +- test/compat.jl | 2 +- test/dt.jl | 6 ++-- test/dummy.jl | 2 +- test/smoke.jl | 2 +- test/stateful.jl | 16 ++++----- 25 files changed, 101 insertions(+), 102 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index ec04dcf..e10f2f9 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -38,10 +38,10 @@ function prepare_4pls(group) response_idxs = sample(rng, 1:20, 10) for (est_nick, ability_estimator) in ability_estimators - next_item_rule = ItemStrategyNextItemRule( + next_item_rule = ItemCriterionRule( ExhaustiveSearch(), ExpectationBasedItemCriterion(PointResponseExpectation(ability_estimator), - AbilityVarianceStateCriterion( + AbilityVariance( integrator, distribution_estimator(ability_estimator))) ) next_item_rule = preallocate(next_item_rule) diff --git a/docs/examples/examples/ability_convergence_3pl.jl b/docs/examples/examples/ability_convergence_3pl.jl index 744fb8b..b31b1f0 100644 --- a/docs/examples/examples/ability_convergence_3pl.jl +++ b/docs/examples/examples/ability_convergence_3pl.jl @@ -20,8 +20,8 @@ using Distributions: Normal, cdf using AlgebraOfGraphics using ComputerAdaptiveTesting using ComputerAdaptiveTesting.Sim: auto_responder -using ComputerAdaptiveTesting.NextItemRules: AbilityVarianceStateCriterion -using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition +using ComputerAdaptiveTesting.NextItemRules: AbilityVariance +using ComputerAdaptiveTesting.TerminationConditions: FixedLength using ComputerAdaptiveTesting.Aggregators: PosteriorAbilityEstimator, MeanAbilityEstimator, LikelihoodAbilityEstimator using FittedItemBanks @@ -49,8 +49,8 @@ integrator = FixedGKIntegrator(-6, 6, 80) dist_ability_est = PosteriorAbilityEstimator(std_normal) ability_estimator = MeanAbilityEstimator(dist_ability_est, integrator) rules = CatRules(ability_estimator, - AbilityVarianceStateCriterion(dist_ability_est, integrator), - FixedItemsTerminationCondition(max_questions)) + AbilityVariance(dist_ability_est, integrator), + FixedLength(max_questions)) points = 500 xs = range(-2.5, 2.5, length = points) diff --git a/docs/examples/examples/ability_convergence_mirt.jl b/docs/examples/examples/ability_convergence_mirt.jl index adb480f..b8cb144 100644 --- a/docs/examples/examples/ability_convergence_mirt.jl +++ b/docs/examples/examples/ability_convergence_mirt.jl @@ -21,7 +21,7 @@ using AlgebraOfGraphics using ComputerAdaptiveTesting using ComputerAdaptiveTesting.Sim: auto_responder using ComputerAdaptiveTesting.NextItemRules: DRuleItemCriterion -using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition +using ComputerAdaptiveTesting.TerminationConditions: FixedLength using ComputerAdaptiveTesting.Aggregators: PosteriorAbilityEstimator, MeanAbilityEstimator, LikelihoodAbilityEstimator using FittedItemBanks @@ -53,7 +53,7 @@ ability_estimator = MeanAbilityEstimator(PosteriorAbilityEstimator(std_mv_normal integrator) rules = CatRules(ability_estimator, DRuleItemCriterion(ability_estimator), - FixedItemsTerminationCondition(max_questions)) + FixedLength(max_questions)) # XXX: We shouldn't need to specify xs here since the distributions are not used -- rework points = 3 diff --git a/docs/examples/examples/vocab_iq.jl b/docs/examples/examples/vocab_iq.jl index 0434daf..009217d 100644 --- a/docs/examples/examples/vocab_iq.jl +++ b/docs/examples/examples/vocab_iq.jl @@ -42,8 +42,8 @@ function run_vocab_iq_cat() ability_estimator = ModeAbilityEstimator(dist_ability_est, optimizer) @info "run_cat" ability_estimator rules = CatRules(ability_estimator, - AbilityVarianceStateCriterion(dist_ability_est, ability_integrator), - FixedItemsTerminationCondition(45)) + AbilityVariance(dist_ability_est, ability_integrator), + FixedLength(45)) function get_response(response_idx, response_name) params = item_params(item_bank, response_idx) println("Parameters for next question: $params") diff --git a/docs/src/creating_a_cat.md b/docs/src/creating_a_cat.md index ba30b89..35927bd 100644 --- a/docs/src/creating_a_cat.md +++ b/docs/src/creating_a_cat.md @@ -79,13 +79,13 @@ ComputerAdaptiveTesting.NextItemRules.RandomNextItemRule Other rules are created by combining a `ItemCriterion` -- which somehow rates items according to how good they are -- with a `NextItemStrategy` using an -`ItemStrategyNextItemRule`, which acts as an adapter. The default +`ItemCriterionRule`, which acts as an adapter. The default `NextItemStrategy` (and currently only) is `ExhaustiveSearch`. When using the implicit constructors, `ItemCriterion` can therefore be used directly without wrapping in any place an NextItemRule is expected. ```@docs; canonical=false -ComputerAdaptiveTesting.NextItemRules.ItemStrategyNextItemRule +ComputerAdaptiveTesting.NextItemRules.ItemCriterionRule ``` ```@docs; canonical=false @@ -114,17 +114,17 @@ takes a `ResponseExpectation`: either `PointResponseExpectation` or good a particular state is in terms getting a good estimate of the test takers ability. They look one ply ahead to get the expected value of the ``StateCriterion`` after selecting the given item. The -`AbilityVarianceStateCriterion` looks at the variance of the ability ``\theta`` +`AbilityVariance` looks at the variance of the ability ``\theta`` estimate at that state. ### Stopping rules with `TerminationCondition` -Currently the only `TerminationCondition` is `FixedItemsTerminationCondition`, which ends the test after a fixed number of items. +Currently the only `TerminationCondition` is `FixedLength`, which ends the test after a fixed number of items. ```@docs; canonical=false ComputerAdaptiveTesting.TerminationConditions.TerminationCondition ``` ```@docs; canonical=false -ComputerAdaptiveTesting.TerminationConditions.FixedItemsTerminationCondition +ComputerAdaptiveTesting.TerminationConditions.FixedLength ``` diff --git a/docs/src/stateful.md b/docs/src/stateful.md index aaa24eb..0fea784 100644 --- a/docs/src/stateful.md +++ b/docs/src/stateful.md @@ -28,7 +28,7 @@ Stateful.get_ability There is an implementation in terms of [CatRules](@ref): ```@docs -Stateful.StatefulCatConfig +Stateful.StatefulCatRules ``` ## Usage diff --git a/src/Compat/CatR.jl b/src/Compat/CatR.jl index 9da15c3..c17cb6e 100644 --- a/src/Compat/CatR.jl +++ b/src/Compat/CatR.jl @@ -6,7 +6,7 @@ using ComputerAdaptiveTesting.Aggregators: AbilityIntegrator, ModeAbilityEstimator, MeanAbilityEstimator, PosteriorAbilityEstimator -using ComputerAdaptiveTesting.TerminationConditions: RunForeverTerminationCondition +using ComputerAdaptiveTesting.TerminationConditions: RunForever using ComputerAdaptiveTesting.Rules: CatRules using ComputerAdaptiveTesting.NextItemRules using PsychometricsBazaarBase: Integrators, Optimizers @@ -19,15 +19,15 @@ function _next_item_aliases() "MFI" => InformationItemCriterion, "bOpt" => UrryItemCriterion, ) - res[nick] = (bits...; kwargs...) -> ItemStrategyNextItemRule( + res[nick] = (bits...; kwargs...) -> ItemCriterionRule( ExhaustiveSearch(), mk_item_criterion(bits...)) end - res["MEPV"] = (bits...; posterior_ability_estimator, kwargs...) -> ItemStrategyNextItemRule( + res["MEPV"] = (bits...; posterior_ability_estimator, kwargs...) -> ItemCriterionRule( ExhaustiveSearch(), ExpectationBasedItemCriterion(bits..., - AbilityVarianceStateCriterion(posterior_ability_estimator, AbilityIntegrator(bits...)))) - res["MEI"] = (bits...; kwargs...) -> ItemStrategyNextItemRule( + AbilityVariance(posterior_ability_estimator, AbilityIntegrator(bits...)))) + res["MEI"] = (bits...; kwargs...) -> ItemCriterionRule( ExhaustiveSearch(), ExpectationBasedItemCriterion(bits..., InformationItemCriterion(bits...))) @@ -64,7 +64,7 @@ const ability_estimator_aliases = _ability_estimator_aliases() #= for (resp_exp, resp_exp_nick) in resp_exp_nick_pairs next_item_rule = NextItemRule( - ExpectationBasedItemCriterion(resp_exp, AbilityVarianceStateCriterion(numtools.integrator, distribution_estimator(abil_est))) + ExpectationBasedItemCriterion(resp_exp, AbilityVariance(numtools.integrator, distribution_estimator(abil_est))) ) next_item_rule = preallocate(next_item_rule) est_next_item_rule_pairs[Symbol("$(abil_est_str)_mepv_$(resp_exp_nick)")] = (abil_est, next_item_rule) @@ -99,10 +99,10 @@ function assemble_rules(; ability_estimator = ability_estimator_aliases[method](; integrator, optimizer) posterior_ability_estimator = PosteriorAbilityEstimator() raw_next_item = next_item_aliases[criterion](ability_estimator, integrator, optimizer; posterior_ability_estimator=posterior_ability_estimator) - next_item = FixedFirstItemNextItemRule(start_item, raw_next_item) + next_item = FixedFirstItem(start_item, raw_next_item) CatRules(; next_item, - termination_condition = RunForeverTerminationCondition(), + termination_condition = RunForever(), ability_estimator, #ability_tracker::AbilityTrackerT = NullAbilityTracker() ) diff --git a/src/Compat/MirtCAT.jl b/src/Compat/MirtCAT.jl index ae2865b..491fb24 100644 --- a/src/Compat/MirtCAT.jl +++ b/src/Compat/MirtCAT.jl @@ -8,7 +8,7 @@ using ComputerAdaptiveTesting.Aggregators: SafeLikelihoodAbilityEstimator, PosteriorAbilityEstimator, AbilityEstimator, distribution_estimator -using ComputerAdaptiveTesting.TerminationConditions: RunForeverTerminationCondition +using ComputerAdaptiveTesting.TerminationConditions: RunForever using ComputerAdaptiveTesting.NextItemRules using ComputerAdaptiveTesting.Rules: CatRules using PsychometricsBazaarBase: Integrators, Optimizers @@ -23,7 +23,7 @@ function _next_item_helper(item_criterion_callback) optimizer, ] item_criterion = item_criterion_callback(; bits, ability_estimator, posterior_ability_estimator, integrator, optimizer) - return ItemStrategyNextItemRule(ExhaustiveSearch(), item_criterion) + return ItemCriterionRule(ExhaustiveSearch(), item_criterion) end return _helper end @@ -34,7 +34,7 @@ const next_item_aliases = Dict( # 'MEPV' for minimum expected posterior variance "MEPV" => _next_item_helper((; bits, ability_estimator, posterior_ability_estimator, integrator, rest...) -> ExpectationBasedItemCriterion( ability_estimator, - AbilityVarianceStateCriterion(posterior_ability_estimator, integrator))), + AbilityVariance(posterior_ability_estimator, integrator))), "MEI" => _next_item_helper((; bits, ability_estimator, rest...) -> ExpectationBasedItemCriterion( ability_estimator, PointItemCategoryCriterion(EmpiricalInformationPointwiseItemCategoryCriterion(), ability_estimator) @@ -145,11 +145,11 @@ function assemble_rules(; ability_estimator = ability_estimator_aliases[method](; integrator, optimizer, ncomp) posterior_ability_estimator = PosteriorAbilityEstimator(; ncomp) raw_next_item = next_item_aliases[criteria](ability_estimator, posterior_ability_estimator, integrator, optimizer) - next_item = FixedFirstItemNextItemRule(start_item, raw_next_item) + next_item = FixedFirstItem(start_item, raw_next_item) CatRules(; next_item, ability_estimator, - termination_condition = RunForeverTerminationCondition(), + termination_condition = RunForever(), ) end diff --git a/src/Stateful.jl b/src/Stateful.jl index 589651f..f8482e6 100644 --- a/src/Stateful.jl +++ b/src/Stateful.jl @@ -14,7 +14,7 @@ using ..Responses: BareResponses, Response, Responses using ..NextItemRules: compute_criteria, best_item using ..Sim: CatLoop, Sim, item_label -export StatefulCat, StatefulCatConfig +export StatefulCat, StatefulCatRules public next_item, ranked_items, item_criteria public add_response!, rollback!, reset!, get_responses, get_ability @@ -190,26 +190,28 @@ $(TYPEDSIGNATURES) This is a the `StatefulCat` implementation in terms of `CatRules`. It is also the de-facto standard for the behavior of the interface. """ -struct StatefulCatConfig{TrackedResponsesT <: TrackedResponses} <: StatefulCat +struct StatefulCatRules{TrackedResponsesT <: TrackedResponses} <: StatefulCat rules::CatRules tracked_responses::Ref{TrackedResponsesT} end -function StatefulCatConfig(rules::CatRules, item_bank::AbstractItemBank) +function StatefulCatRules(rules::CatRules, item_bank::AbstractItemBank) bare_responses = BareResponses(ResponseType(item_bank)) tracked_responses = TrackedResponses( bare_responses, item_bank, rules.ability_tracker ) - return StatefulCatConfig(rules, Ref(tracked_responses)) + return StatefulCatRules(rules, Ref(tracked_responses)) end -function next_item(config::StatefulCatConfig) +StatefulCat(rules::CatRules, item_bank::AbstractItemBank) = StatefulCatRules(rules, item_bank) + +function next_item(config::StatefulCatRules) return best_item(config.rules.next_item, config.tracked_responses[]) end -function ranked_items(config::StatefulCatConfig) +function ranked_items(config::StatefulCatRules) criteria = compute_criteria( config.rules.next_item, config.tracked_responses[]) if criteria === nothing @@ -218,27 +220,27 @@ function ranked_items(config::StatefulCatConfig) return sortperm(criteria) end -function item_criteria(config::StatefulCatConfig) +function item_criteria(config::StatefulCatRules) return compute_criteria( config.rules.next_item, config.tracked_responses[]) end -function add_response!(config::StatefulCatConfig, index, response) +function add_response!(config::StatefulCatRules, index, response) tracked_responses = config.tracked_responses[] Responses.add_response!( tracked_responses, Response( ResponseType(tracked_responses.item_bank), index, response)) end -function rollback!(config::StatefulCatConfig) +function rollback!(config::StatefulCatRules) Responses.pop_response!(config.tracked_responses[]) end -function reset!(config::StatefulCatConfig) +function reset!(config::StatefulCatRules) empty!(config.tracked_responses[]) end -function set_item_bank!(config::StatefulCatConfig, item_bank) +function set_item_bank!(config::StatefulCatRules, item_bank) bare_responses = BareResponses(ResponseType(item_bank)) config.tracked_responses[] = TrackedResponses( bare_responses, @@ -247,23 +249,23 @@ function set_item_bank!(config::StatefulCatConfig, item_bank) ) end -function get_responses(config::StatefulCatConfig) +function get_responses(config::StatefulCatRules) return config.tracked_responses[].responses end -function get_ability(config::StatefulCatConfig) +function get_ability(config::StatefulCatRules) return (config.rules.ability_estimator(config.tracked_responses[]), nothing) end -function likelihood(config::StatefulCatConfig, ability) +function likelihood(config::StatefulCatRules, ability) pdf(distribution_estimator(config.rules.ability_estimator), config.tracked_responses[], ability) end -function item_bank_size(config::StatefulCatConfig) +function item_bank_size(config::StatefulCatRules) return length(config.tracked_responses[].item_bank) end -function item_response_functions(config::StatefulCatConfig, index, ability) +function item_response_functions(config::StatefulCatRules, index, ability) item_bank = config.tracked_responses[].item_bank item_response = ItemResponse(item_bank, index) return resp_vec(item_response, ability) diff --git a/src/TerminationConditions.jl b/src/TerminationConditions.jl index 9e4fa63..3ab7fe8 100644 --- a/src/TerminationConditions.jl +++ b/src/TerminationConditions.jl @@ -9,8 +9,8 @@ using FittedItemBanks import Base: show export TerminationCondition, - LengthTerminationCondition, SimpleFunctionTerminationCondition -export RunForeverTerminationCondition + FixedLength, SimpleFunctionTerminationCondition +export RunForever """ $(TYPEDEF) @@ -25,21 +25,18 @@ end $(TYPEDEF) $(TYPEDFIELDS) """ -struct LengthTerminationCondition{} <: TerminationCondition +struct FixedLength{} <: TerminationCondition num_items::Int64 end -function (condition::LengthTerminationCondition)(responses::TrackedResponses, +function (condition::FixedLength)(responses::TrackedResponses, items::AbstractItemBank) length(responses) >= condition.num_items end -function show(io::IO, ::MIME"text/plain", condition::LengthTerminationCondition) +function show(io::IO, ::MIME"text/plain", condition::FixedLength) println(io, "Terminate test after administering $(condition.num_items) items") end -# Alias for old name -const FixedItemsTerminationCondition = LengthTerminationCondition - struct SimpleFunctionTerminationCondition{F} <: TerminationCondition func::F end @@ -48,8 +45,8 @@ function (condition::SimpleFunctionTerminationCondition)(responses::TrackedRespo condition.func(responses, items) end -struct RunForeverTerminationCondition <: TerminationCondition end -function (condition::RunForeverTerminationCondition)(::TrackedResponses, ::AbstractItemBank) +struct RunForever <: TerminationCondition end +function (condition::RunForever)(::TrackedResponses, ::AbstractItemBank) return false end diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index ec04fb6..def1dfb 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -37,8 +37,8 @@ using ConstructionBase: constructorof import ForwardDiff import Base: show -export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread -export NextItemRule, ItemStrategyNextItemRule +export ExpectationBasedItemCriterion, AbilityVariance, init_thread +export NextItemRule, ItemCriterionRule export UrryItemCriterion, InformationItemCriterion export LikelihoodWeightedItemCriterion, PointItemCriterion export LikelihoodWeightedItemCategoryCriterion, PointItemCategoryCriterion @@ -47,7 +47,7 @@ export RawEmpiricalInformationPointwiseItemCategoryCriterion export EmpiricalInformationPointwiseItemCategoryCriterion export TotalItemInformation export RandomNextItemRule -export FixedRuleSequencer, MemoryNextItemRule, FixedFirstItemNextItemRule +export FixedRuleSequencer, MemoryNextItemRule, FixedFirstItem export ExhaustiveSearch, RandomesqueStrategy export preallocate export compute_criteria, compute_criterion, compute_multi_criterion diff --git a/src/next_item_rules/criteria/state/ability_variance.jl b/src/next_item_rules/criteria/state/ability_variance.jl index 3ee74db..47d7232 100644 --- a/src/next_item_rules/criteria/state/ability_variance.jl +++ b/src/next_item_rules/criteria/state/ability_variance.jl @@ -5,7 +5,7 @@ $(TYPEDFIELDS) This `StateCriterion` returns the variance of the ability estimate given a set of responses. """ -struct AbilityVarianceStateCriterion{ +struct AbilityVariance{ DistEst <: DistributionAbilityEstimator, IntegratorT <: AbilityIntegrator } <: StateCriterion @@ -14,15 +14,15 @@ struct AbilityVarianceStateCriterion{ skip_zero::Bool end -function AbilityVarianceStateCriterion(bits...) +function AbilityVariance(bits...) skip_zero = false - @returnsome find1_instance(AbilityVarianceStateCriterion, bits) + @returnsome find1_instance(AbilityVariance, bits) @requiresome dist_est_integrator_pair = get_dist_est_and_integrator(bits...) (dist_est, integrator) = dist_est_integrator_pair - return AbilityVarianceStateCriterion(dist_est, integrator, skip_zero) + return AbilityVariance(dist_est, integrator, skip_zero) end -function compute_criterion(criterion::AbilityVarianceStateCriterion, +function compute_criterion(criterion::AbilityVariance, tracked_responses::TrackedResponses)::Float64 # XXX: Not sure if the estimator should come from somewhere else here denom = normdenom(criterion.integrator, @@ -35,7 +35,7 @@ function compute_criterion(criterion::AbilityVarianceStateCriterion, criterion, DomainType(tracked_responses.item_bank), tracked_responses, denom) end -function compute_criterion(criterion::AbilityVarianceStateCriterion, +function compute_criterion(criterion::AbilityVariance, ::Union{OneDimContinuousDomain, DiscreteDomain}, tracked_responses::TrackedResponses, denom)::Float64 @@ -48,7 +48,7 @@ function compute_criterion(criterion::AbilityVarianceStateCriterion, end function compute_criterion( - criterion::AbilityVarianceStateCriterion, + criterion::AbilityVariance, ::Vector, tracked_responses::TrackedResponses, denom @@ -68,7 +68,7 @@ function compute_criterion( denom) end -function show(io::IO, ::MIME"text/plain", criterion::AbilityVarianceStateCriterion) +function show(io::IO, ::MIME"text/plain", criterion::AbilityVariance) println(io, "Minimise variance of ability estimate") indent_io = indent(io, 2) show(indent_io, MIME("text/plain"), criterion.dist_est) diff --git a/src/next_item_rules/prelude/abstract.jl b/src/next_item_rules/prelude/abstract.jl index 5679712..52ee8ae 100644 --- a/src/next_item_rules/prelude/abstract.jl +++ b/src/next_item_rules/prelude/abstract.jl @@ -6,21 +6,21 @@ Abstract base type for all item selection rules. All descendants of this type are expected to implement the interface `(::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int`. -In practice, all adaptive rules in this package use `ItemStrategyNextItemRule`. +In practice, all adaptive rules in this package use `ItemCriterionRule`. $(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true) Implicit constructor for $(FUNCTIONNAME). Uses any given `NextItemRule` or -delegates to `ItemStrategyNextItemRule` the default instance. +delegates to `ItemCriterionRule` the default instance. """ abstract type NextItemRule <: CatConfigBase end """ $(TYPEDEF) -Abstract type for next item strategies, tightly coupled with `ItemStrategyNextItemRule`. +Abstract type for next item strategies, tightly coupled with `ItemCriterionRule`. All descendants of this type are expected to implement the interface -`(rule::ItemStrategyNextItemRule{::NextItemStrategy, ::ItemCriterion})(responses::TrackedResponses, +`(rule::ItemCriterionRule{::NextItemStrategy, ::ItemCriterion})(responses::TrackedResponses, items) where {ItemCriterionT <: } `(strategy::NextItemStrategy)(; parallel=true)::NextItemStrategy` """ diff --git a/src/next_item_rules/prelude/criteria.jl b/src/next_item_rules/prelude/criteria.jl index 31dede9..f64c8ea 100644 --- a/src/next_item_rules/prelude/criteria.jl +++ b/src/next_item_rules/prelude/criteria.jl @@ -75,7 +75,7 @@ function compute_criteria( end function compute_criteria( - rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT}, + rule::ItemCriterionRule{StrategyT, ItemCriterionT}, responses, items ) where {StrategyT, ItemCriterionT <: ItemCriterion} @@ -83,7 +83,7 @@ function compute_criteria( end function compute_criteria( - rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT}, + rule::ItemCriterionRule{StrategyT, ItemCriterionT}, responses::TrackedResponses ) where {StrategyT, ItemCriterionT <: ItemCriterion} compute_criteria(rule.criterion, responses) diff --git a/src/next_item_rules/prelude/next_item_rule.jl b/src/next_item_rules/prelude/next_item_rule.jl index 7da78ae..c61836c 100644 --- a/src/next_item_rules/prelude/next_item_rule.jl +++ b/src/next_item_rules/prelude/next_item_rule.jl @@ -2,7 +2,7 @@ function NextItemRule(bits...; ability_estimator = nothing, ability_tracker = nothing) @returnsome find1_instance(NextItemRule, bits) - @returnsome ItemStrategyNextItemRule(bits..., + @returnsome ItemCriterionRule(bits..., ability_estimator = ability_estimator, ability_tracker = ability_tracker) end @@ -21,7 +21,7 @@ end $(TYPEDEF) $(TYPEDFIELDS) -`ItemStrategyNextItemRule` which together with a `NextItemStrategy` acts as an +`ItemCriterionRule` which together with a `NextItemStrategy` acts as an adapter by which an `ItemCriterion` can serve as a `NextItemRule`. $(FUNCTIONNAME)(bits...; ability_estimator=nothing @@ -29,7 +29,7 @@ adapter by which an `ItemCriterion` can serve as a `NextItemRule`. Implicit constructor for $(FUNCTIONNAME). Will default to `ExhaustiveSearch` when no `NextItemStrategy` is given. """ -struct ItemStrategyNextItemRule{ +struct ItemCriterionRule{ NextItemStrategyT <: NextItemStrategy, ItemCriterionT <: ItemCriterion } <: NextItemRule @@ -37,7 +37,7 @@ struct ItemStrategyNextItemRule{ criterion::ItemCriterionT end -function ItemStrategyNextItemRule(bits...; +function ItemCriterionRule(bits...; ability_estimator = nothing, ability_tracker = nothing) strategy = NextItemStrategy(bits...) @@ -45,7 +45,7 @@ function ItemStrategyNextItemRule(bits...; ability_estimator = ability_estimator, ability_tracker = ability_tracker) if strategy !== nothing && criterion !== nothing - return ItemStrategyNextItemRule(strategy, criterion) + return ItemCriterionRule(strategy, criterion) end end @@ -53,7 +53,7 @@ function best_item(rule::NextItemRule, tracked_responses::TrackedResponses) best_item(rule, tracked_responses, tracked_responses.item_bank) end -function Base.show(io::IO, ::MIME"text/plain", rule::ItemStrategyNextItemRule) +function Base.show(io::IO, ::MIME"text/plain", rule::ItemCriterionRule) println(io, "Pick optimal item criterion according to strategy") indent_io = indent(io, 2) show(indent_io, MIME"text/plain"(), rule.strategy) diff --git a/src/next_item_rules/strategies/exhaustive.jl b/src/next_item_rules/strategies/exhaustive.jl index 849b91b..c550b8c 100644 --- a/src/next_item_rules/strategies/exhaustive.jl +++ b/src/next_item_rules/strategies/exhaustive.jl @@ -42,7 +42,7 @@ $(TYPEDFIELDS) struct ExhaustiveSearch <: NextItemStrategy end function best_item( - rule::ItemStrategyNextItemRule{ExhaustiveSearch, ItemCriterionT}, + rule::ItemCriterionRule{ExhaustiveSearch, ItemCriterionT}, responses::TrackedResponses, items ) where {ItemCriterionT <: ItemCriterion} diff --git a/src/next_item_rules/strategies/randomesque.jl b/src/next_item_rules/strategies/randomesque.jl index 4ae7e11..b3e0ac5 100644 --- a/src/next_item_rules/strategies/randomesque.jl +++ b/src/next_item_rules/strategies/randomesque.jl @@ -46,7 +46,7 @@ end RandomesqueStrategy(k::Int) = RandomesqueStrategy(Xoshiro(), k) function best_item( - rule::ItemStrategyNextItemRule{RandomesqueStrategy, ItemCriterionT}, + rule::ItemCriterionRule{RandomesqueStrategy, ItemCriterionT}, responses::TrackedResponses, items ) where {ItemCriterionT <: ItemCriterion} diff --git a/src/next_item_rules/strategies/sequential.jl b/src/next_item_rules/strategies/sequential.jl index 1c0c59c..7f98653 100644 --- a/src/next_item_rules/strategies/sequential.jl +++ b/src/next_item_rules/strategies/sequential.jl @@ -63,6 +63,6 @@ function show(io::IO, ::MIME"text/plain", rule::MemoryNextItemRule) println(io, "Present the items indexed: $item_list") end -function FixedFirstItemNextItemRule(item_idx::Int, rule::NextItemRule) +function FixedFirstItem(item_idx::Int, rule::NextItemRule) FixedRuleSequencer((1,), (MemoryNextItemRule((item_idx,)), rule)) end \ No newline at end of file diff --git a/src/precompiles.jl b/src/precompiles.jl index 5147e9c..c05ef84 100644 --- a/src/precompiles.jl +++ b/src/precompiles.jl @@ -7,8 +7,8 @@ using PrecompileTools: @compile_workload, @setup_workload using Random: default_rng using .Aggregators: LikelihoodAbilityEstimator, MeanAbilityEstimator, GriddedAbilityTracker, AbilityIntegrator - using .NextItemRules: preallocate, ExhaustiveSearch, ItemStrategyNextItemRule, - ExpectationBasedItemCriterion, AbilityVarianceStateCriterion + using .NextItemRules: preallocate, ExhaustiveSearch, ItemCriterionRule, + ExpectationBasedItemCriterion, AbilityVariance using .Stateful: Stateful using .ComputerAdaptiveTesting: CatRules @@ -21,13 +21,13 @@ using PrecompileTools: @compile_workload, @setup_workload lh_grid_tracker = GriddedAbilityTracker(lh_ability_est, integrator) ability_integrator = AbilityIntegrator(integrator, lh_grid_tracker) ability_estimator = MeanAbilityEstimator(lh_ability_est, ability_integrator) - next_item_rule = ItemStrategyNextItemRule( + next_item_rule = ItemCriterionRule( ExhaustiveSearch(), ExpectationBasedItemCriterion(ability_estimator, - AbilityVarianceStateCriterion(ability_estimator))) - cat = Stateful.StatefulCatConfig(CatRules(; + AbilityVariance(ability_estimator))) + cat = Stateful.StatefulCatRules(CatRules(; next_item=next_item_rule, - termination_condition=TerminationConditions.RunForeverTerminationCondition(), + termination_condition=TerminationConditions.RunForever(), ability_estimator=ability_estimator ), item_bank) Stateful.add_response!(cat, 1, 0) diff --git a/test/ability_estimator_1d.jl b/test/ability_estimator_1d.jl index 0869c9c..c76bea0 100644 --- a/test/ability_estimator_1d.jl +++ b/test/ability_estimator_1d.jl @@ -72,7 +72,7 @@ mle_mode_1d = ModeAbilityEstimator(lh_est_1d, optimizer_1d) ) end - ability_variance_state_criterion = AbilityVarianceStateCriterion( + ability_variance_state_criterion = AbilityVariance( lh_est_1d, integrator_1d) ability_variance_item_criterion = ExpectationBasedItemCriterion( mle_mean_1d, diff --git a/test/compat.jl b/test/compat.jl index 54a6e9e..984e7ba 100644 --- a/test/compat.jl +++ b/test/compat.jl @@ -3,7 +3,7 @@ using FittedItemBanks: OneDimContinuousDomain, SimpleItemBankSpec, StdModel3PL, BooleanResponse using ComputerAdaptiveTesting.Aggregators: TrackedResponses, NullAbilityTracker - using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition + using ComputerAdaptiveTesting.TerminationConditions: FixedLength using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule using ComputerAdaptiveTesting.Responses: BareResponses, ResponseType using ComputerAdaptiveTesting: Stateful diff --git a/test/dt.jl b/test/dt.jl index cf5daa1..65d0660 100644 --- a/test/dt.jl +++ b/test/dt.jl @@ -9,12 +9,12 @@ ability_estimator = MeanAbilityEstimator(LikelihoodAbilityEstimator(), integrato get_response = auto_responder(@view true_responses[:, 1]) @testset "decision tree round trip" begin - next_item_rule = ItemStrategyNextItemRule( - AbilityVarianceStateCriterion( + next_item_rule = ItemCriterionRule( + AbilityVariance( distribution_estimator(ability_estimator), integrator), ability_estimator = ability_estimator ) - termination_condition = FixedItemsTerminationCondition(4) + termination_condition = FixedLength(4) cat_rules = CatRules( next_item = next_item_rule, diff --git a/test/dummy.jl b/test/dummy.jl index 20e644e..6d52a53 100644 --- a/test/dummy.jl +++ b/test/dummy.jl @@ -38,7 +38,7 @@ const ability_estimators_1d = [ ] const criteria_1d = [ ((:integrator, :est), - (stuff) -> AbilityVarianceStateCriterion( + (stuff) -> AbilityVariance( distribution_estimator(stuff.est), stuff.integrator)), ((:est,), (stuff) -> InformationItemCriterion(stuff.est)), ((:est,), (stuff) -> UrryItemCriterion(stuff.est)), diff --git a/test/smoke.jl b/test/smoke.jl index d63402a..f5b1891 100644 --- a/test/smoke.jl +++ b/test/smoke.jl @@ -22,7 +22,7 @@ using .Dummy function test1d(ability_estimator, bits...) rules = CatRules( - FixedItemsTerminationCondition(2), + FixedLength(2), ability_estimator, bits... ) diff --git a/test/stateful.jl b/test/stateful.jl index 5e1ff80..19f0b54 100644 --- a/test/stateful.jl +++ b/test/stateful.jl @@ -3,7 +3,7 @@ using FittedItemBanks.DummyData: dummy_full using FittedItemBanks: OneDimContinuousDomain, SimpleItemBankSpec, StdModel3PL, BooleanResponse - using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition + using ComputerAdaptiveTesting.TerminationConditions: FixedLength using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule using ComputerAdaptiveTesting: Stateful using ComputerAdaptiveTesting: require_testext @@ -23,15 +23,15 @@ num_testees = 2 ) - @testset "StatefulCatConfig basic usage" begin + @testset "StatefulCatRules basic usage" begin rules = CatRules( - FixedItemsTerminationCondition(2), + FixedLength(2), Dummy.DummyAbilityEstimator(0.0), RandomNextItemRule() ) # Initialize config - cat_config = Stateful.StatefulCatConfig(rules, item_bank) + cat_config = Stateful.StatefulCatRules(rules, item_bank) # Test initialization state @test isempty(Stateful.get_responses(cat_config)) @@ -53,11 +53,11 @@ @testset "Stateful next item selection" begin rules = CatRules( - FixedItemsTerminationCondition(2), + FixedLength(2), Dummy.DummyAbilityEstimator(0.0), RandomNextItemRule() ) - cat_config = Stateful.StatefulCatConfig(rules, item_bank) + cat_config = Stateful.StatefulCatRules(rules, item_bank) # Test first item selection first_item = Stateful.next_item(cat_config) @@ -72,13 +72,13 @@ @testset "Standard interface tests" begin rules = CatRules( - FixedItemsTerminationCondition(2), + FixedLength(2), Dummy.DummyAbilityEstimator(0.0), RandomNextItemRule() ) # Initialize config - cat_config = Stateful.StatefulCatConfig(rules, item_bank) + cat_config = Stateful.StatefulCatRules(rules, item_bank) # Run the standard interface tests TestExt = require_testext() From 9231d4b36f36ccd5cba693a5df5a6324540d58b9 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 5 Jul 2025 01:19:50 +0300 Subject: [PATCH 38/42] SimpleFunctionTerminationCondition => TerminationTest --- src/TerminationConditions.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/TerminationConditions.jl b/src/TerminationConditions.jl index 3ab7fe8..45867c7 100644 --- a/src/TerminationConditions.jl +++ b/src/TerminationConditions.jl @@ -8,8 +8,7 @@ using PsychometricsBazaarBase.ConfigTools: @returnsome, find1_instance using FittedItemBanks import Base: show -export TerminationCondition, - FixedLength, SimpleFunctionTerminationCondition +export TerminationCondition, FixedLength, TerminationTest export RunForever """ @@ -37,12 +36,12 @@ function show(io::IO, ::MIME"text/plain", condition::FixedLength) println(io, "Terminate test after administering $(condition.num_items) items") end -struct SimpleFunctionTerminationCondition{F} <: TerminationCondition - func::F +struct TerminationTest{F} <: TerminationCondition + condition::F end -function (condition::SimpleFunctionTerminationCondition)(responses::TrackedResponses, +function (condition::TerminationTest)(responses::TrackedResponses, items::AbstractItemBank) - condition.func(responses, items) + condition.condition(responses, items) end struct RunForever <: TerminationCondition end From 516550994b168d4ae9a2f6b82ef8e15ed034a1a3 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 5 Jul 2025 01:35:19 +0300 Subject: [PATCH 39/42] More consistent path/module naming --- src/{aggregators => Aggregators}/Aggregators.jl | 0 src/{aggregators => Aggregators}/ability_estimator.jl | 0 src/{aggregators => Aggregators}/ability_tracker.jl | 0 .../ability_trackers/closed_form_normal.jl | 0 src/{aggregators => Aggregators}/ability_trackers/grid.jl | 0 .../ability_trackers/laplace.jl | 0 .../ability_trackers/point.jl | 0 src/{aggregators => Aggregators}/optimizers.jl | 0 src/{aggregators => Aggregators}/riemann.jl | 0 src/{aggregators => Aggregators}/slow.jl | 0 src/{aggregators => Aggregators}/speculators.jl | 0 src/{aggregators => Aggregators}/tracked.jl | 0 src/ComputerAdaptiveTesting.jl | 8 ++++---- src/{decision_tree => DecisionTree}/DecisionTree.jl | 0 src/{decision_tree => DecisionTree}/mmap.jl | 0 src/{decision_tree => DecisionTree}/sim.jl | 0 src/{next_item_rules => NextItemRules}/NextItemRules.jl | 0 .../combinators/expectation.jl | 0 .../combinators/likelihood.jl | 0 .../combinators/scalarizers.jl | 0 .../criteria/item/information.jl | 0 .../criteria/item/urry.jl | 0 .../criteria/pointwise/information.jl | 0 .../criteria/pointwise/information_special.jl | 0 .../criteria/pointwise/information_support.jl | 0 .../criteria/pointwise/kl.jl | 0 .../criteria/state/ability_variance.jl | 0 .../porcelain/aliases.jl | 0 .../porcelain/porcelain.jl | 0 .../prelude/abstract.jl | 0 .../prelude/criteria.jl | 0 .../prelude/next_item_rule.jl | 0 .../prelude/preallocate.jl | 0 .../strategies/balance.jl | 0 .../strategies/exhaustive.jl | 0 .../strategies/pointwise.jl | 0 .../strategies/random.jl | 0 .../strategies/randomesque.jl | 0 .../strategies/sequential.jl | 0 src/{sim => Sim}/Sim.jl | 0 src/{sim => Sim}/loop.jl | 0 src/{sim => Sim}/recorder.jl | 0 src/{sim => Sim}/run.jl | 0 43 files changed, 4 insertions(+), 4 deletions(-) rename src/{aggregators => Aggregators}/Aggregators.jl (100%) rename src/{aggregators => Aggregators}/ability_estimator.jl (100%) rename src/{aggregators => Aggregators}/ability_tracker.jl (100%) rename src/{aggregators => Aggregators}/ability_trackers/closed_form_normal.jl (100%) rename src/{aggregators => Aggregators}/ability_trackers/grid.jl (100%) rename src/{aggregators => Aggregators}/ability_trackers/laplace.jl (100%) rename src/{aggregators => Aggregators}/ability_trackers/point.jl (100%) rename src/{aggregators => Aggregators}/optimizers.jl (100%) rename src/{aggregators => Aggregators}/riemann.jl (100%) rename src/{aggregators => Aggregators}/slow.jl (100%) rename src/{aggregators => Aggregators}/speculators.jl (100%) rename src/{aggregators => Aggregators}/tracked.jl (100%) rename src/{decision_tree => DecisionTree}/DecisionTree.jl (100%) rename src/{decision_tree => DecisionTree}/mmap.jl (100%) rename src/{decision_tree => DecisionTree}/sim.jl (100%) rename src/{next_item_rules => NextItemRules}/NextItemRules.jl (100%) rename src/{next_item_rules => NextItemRules}/combinators/expectation.jl (100%) rename src/{next_item_rules => NextItemRules}/combinators/likelihood.jl (100%) rename src/{next_item_rules => NextItemRules}/combinators/scalarizers.jl (100%) rename src/{next_item_rules => NextItemRules}/criteria/item/information.jl (100%) rename src/{next_item_rules => NextItemRules}/criteria/item/urry.jl (100%) rename src/{next_item_rules => NextItemRules}/criteria/pointwise/information.jl (100%) rename src/{next_item_rules => NextItemRules}/criteria/pointwise/information_special.jl (100%) rename src/{next_item_rules => NextItemRules}/criteria/pointwise/information_support.jl (100%) rename src/{next_item_rules => NextItemRules}/criteria/pointwise/kl.jl (100%) rename src/{next_item_rules => NextItemRules}/criteria/state/ability_variance.jl (100%) rename src/{next_item_rules => NextItemRules}/porcelain/aliases.jl (100%) rename src/{next_item_rules => NextItemRules}/porcelain/porcelain.jl (100%) rename src/{next_item_rules => NextItemRules}/prelude/abstract.jl (100%) rename src/{next_item_rules => NextItemRules}/prelude/criteria.jl (100%) rename src/{next_item_rules => NextItemRules}/prelude/next_item_rule.jl (100%) rename src/{next_item_rules => NextItemRules}/prelude/preallocate.jl (100%) rename src/{next_item_rules => NextItemRules}/strategies/balance.jl (100%) rename src/{next_item_rules => NextItemRules}/strategies/exhaustive.jl (100%) rename src/{next_item_rules => NextItemRules}/strategies/pointwise.jl (100%) rename src/{next_item_rules => NextItemRules}/strategies/random.jl (100%) rename src/{next_item_rules => NextItemRules}/strategies/randomesque.jl (100%) rename src/{next_item_rules => NextItemRules}/strategies/sequential.jl (100%) rename src/{sim => Sim}/Sim.jl (100%) rename src/{sim => Sim}/loop.jl (100%) rename src/{sim => Sim}/recorder.jl (100%) rename src/{sim => Sim}/run.jl (100%) diff --git a/src/aggregators/Aggregators.jl b/src/Aggregators/Aggregators.jl similarity index 100% rename from src/aggregators/Aggregators.jl rename to src/Aggregators/Aggregators.jl diff --git a/src/aggregators/ability_estimator.jl b/src/Aggregators/ability_estimator.jl similarity index 100% rename from src/aggregators/ability_estimator.jl rename to src/Aggregators/ability_estimator.jl diff --git a/src/aggregators/ability_tracker.jl b/src/Aggregators/ability_tracker.jl similarity index 100% rename from src/aggregators/ability_tracker.jl rename to src/Aggregators/ability_tracker.jl diff --git a/src/aggregators/ability_trackers/closed_form_normal.jl b/src/Aggregators/ability_trackers/closed_form_normal.jl similarity index 100% rename from src/aggregators/ability_trackers/closed_form_normal.jl rename to src/Aggregators/ability_trackers/closed_form_normal.jl diff --git a/src/aggregators/ability_trackers/grid.jl b/src/Aggregators/ability_trackers/grid.jl similarity index 100% rename from src/aggregators/ability_trackers/grid.jl rename to src/Aggregators/ability_trackers/grid.jl diff --git a/src/aggregators/ability_trackers/laplace.jl b/src/Aggregators/ability_trackers/laplace.jl similarity index 100% rename from src/aggregators/ability_trackers/laplace.jl rename to src/Aggregators/ability_trackers/laplace.jl diff --git a/src/aggregators/ability_trackers/point.jl b/src/Aggregators/ability_trackers/point.jl similarity index 100% rename from src/aggregators/ability_trackers/point.jl rename to src/Aggregators/ability_trackers/point.jl diff --git a/src/aggregators/optimizers.jl b/src/Aggregators/optimizers.jl similarity index 100% rename from src/aggregators/optimizers.jl rename to src/Aggregators/optimizers.jl diff --git a/src/aggregators/riemann.jl b/src/Aggregators/riemann.jl similarity index 100% rename from src/aggregators/riemann.jl rename to src/Aggregators/riemann.jl diff --git a/src/aggregators/slow.jl b/src/Aggregators/slow.jl similarity index 100% rename from src/aggregators/slow.jl rename to src/Aggregators/slow.jl diff --git a/src/aggregators/speculators.jl b/src/Aggregators/speculators.jl similarity index 100% rename from src/aggregators/speculators.jl rename to src/Aggregators/speculators.jl diff --git a/src/aggregators/tracked.jl b/src/Aggregators/tracked.jl similarity index 100% rename from src/aggregators/tracked.jl rename to src/Aggregators/tracked.jl diff --git a/src/ComputerAdaptiveTesting.jl b/src/ComputerAdaptiveTesting.jl index 033ba97..ec81fb5 100644 --- a/src/ComputerAdaptiveTesting.jl +++ b/src/ComputerAdaptiveTesting.jl @@ -23,19 +23,19 @@ include("./ConfigBase.jl") include("./Responses.jl") # Near base -include("./aggregators/Aggregators.jl") +include("./Aggregators/Aggregators.jl") # Extra item banks include("./logitembank.jl") # Stages -include("./next_item_rules/NextItemRules.jl") +include("./NextItemRules/NextItemRules.jl") include("./TerminationConditions.jl") # Combining / running include("./Rules.jl") -include("./sim/Sim.jl") -include("./decision_tree/DecisionTree.jl") +include("./Sim/Sim.jl") +include("./DecisionTree/DecisionTree.jl") # Stateful layer, compat, and comparison include("./Stateful.jl") diff --git a/src/decision_tree/DecisionTree.jl b/src/DecisionTree/DecisionTree.jl similarity index 100% rename from src/decision_tree/DecisionTree.jl rename to src/DecisionTree/DecisionTree.jl diff --git a/src/decision_tree/mmap.jl b/src/DecisionTree/mmap.jl similarity index 100% rename from src/decision_tree/mmap.jl rename to src/DecisionTree/mmap.jl diff --git a/src/decision_tree/sim.jl b/src/DecisionTree/sim.jl similarity index 100% rename from src/decision_tree/sim.jl rename to src/DecisionTree/sim.jl diff --git a/src/next_item_rules/NextItemRules.jl b/src/NextItemRules/NextItemRules.jl similarity index 100% rename from src/next_item_rules/NextItemRules.jl rename to src/NextItemRules/NextItemRules.jl diff --git a/src/next_item_rules/combinators/expectation.jl b/src/NextItemRules/combinators/expectation.jl similarity index 100% rename from src/next_item_rules/combinators/expectation.jl rename to src/NextItemRules/combinators/expectation.jl diff --git a/src/next_item_rules/combinators/likelihood.jl b/src/NextItemRules/combinators/likelihood.jl similarity index 100% rename from src/next_item_rules/combinators/likelihood.jl rename to src/NextItemRules/combinators/likelihood.jl diff --git a/src/next_item_rules/combinators/scalarizers.jl b/src/NextItemRules/combinators/scalarizers.jl similarity index 100% rename from src/next_item_rules/combinators/scalarizers.jl rename to src/NextItemRules/combinators/scalarizers.jl diff --git a/src/next_item_rules/criteria/item/information.jl b/src/NextItemRules/criteria/item/information.jl similarity index 100% rename from src/next_item_rules/criteria/item/information.jl rename to src/NextItemRules/criteria/item/information.jl diff --git a/src/next_item_rules/criteria/item/urry.jl b/src/NextItemRules/criteria/item/urry.jl similarity index 100% rename from src/next_item_rules/criteria/item/urry.jl rename to src/NextItemRules/criteria/item/urry.jl diff --git a/src/next_item_rules/criteria/pointwise/information.jl b/src/NextItemRules/criteria/pointwise/information.jl similarity index 100% rename from src/next_item_rules/criteria/pointwise/information.jl rename to src/NextItemRules/criteria/pointwise/information.jl diff --git a/src/next_item_rules/criteria/pointwise/information_special.jl b/src/NextItemRules/criteria/pointwise/information_special.jl similarity index 100% rename from src/next_item_rules/criteria/pointwise/information_special.jl rename to src/NextItemRules/criteria/pointwise/information_special.jl diff --git a/src/next_item_rules/criteria/pointwise/information_support.jl b/src/NextItemRules/criteria/pointwise/information_support.jl similarity index 100% rename from src/next_item_rules/criteria/pointwise/information_support.jl rename to src/NextItemRules/criteria/pointwise/information_support.jl diff --git a/src/next_item_rules/criteria/pointwise/kl.jl b/src/NextItemRules/criteria/pointwise/kl.jl similarity index 100% rename from src/next_item_rules/criteria/pointwise/kl.jl rename to src/NextItemRules/criteria/pointwise/kl.jl diff --git a/src/next_item_rules/criteria/state/ability_variance.jl b/src/NextItemRules/criteria/state/ability_variance.jl similarity index 100% rename from src/next_item_rules/criteria/state/ability_variance.jl rename to src/NextItemRules/criteria/state/ability_variance.jl diff --git a/src/next_item_rules/porcelain/aliases.jl b/src/NextItemRules/porcelain/aliases.jl similarity index 100% rename from src/next_item_rules/porcelain/aliases.jl rename to src/NextItemRules/porcelain/aliases.jl diff --git a/src/next_item_rules/porcelain/porcelain.jl b/src/NextItemRules/porcelain/porcelain.jl similarity index 100% rename from src/next_item_rules/porcelain/porcelain.jl rename to src/NextItemRules/porcelain/porcelain.jl diff --git a/src/next_item_rules/prelude/abstract.jl b/src/NextItemRules/prelude/abstract.jl similarity index 100% rename from src/next_item_rules/prelude/abstract.jl rename to src/NextItemRules/prelude/abstract.jl diff --git a/src/next_item_rules/prelude/criteria.jl b/src/NextItemRules/prelude/criteria.jl similarity index 100% rename from src/next_item_rules/prelude/criteria.jl rename to src/NextItemRules/prelude/criteria.jl diff --git a/src/next_item_rules/prelude/next_item_rule.jl b/src/NextItemRules/prelude/next_item_rule.jl similarity index 100% rename from src/next_item_rules/prelude/next_item_rule.jl rename to src/NextItemRules/prelude/next_item_rule.jl diff --git a/src/next_item_rules/prelude/preallocate.jl b/src/NextItemRules/prelude/preallocate.jl similarity index 100% rename from src/next_item_rules/prelude/preallocate.jl rename to src/NextItemRules/prelude/preallocate.jl diff --git a/src/next_item_rules/strategies/balance.jl b/src/NextItemRules/strategies/balance.jl similarity index 100% rename from src/next_item_rules/strategies/balance.jl rename to src/NextItemRules/strategies/balance.jl diff --git a/src/next_item_rules/strategies/exhaustive.jl b/src/NextItemRules/strategies/exhaustive.jl similarity index 100% rename from src/next_item_rules/strategies/exhaustive.jl rename to src/NextItemRules/strategies/exhaustive.jl diff --git a/src/next_item_rules/strategies/pointwise.jl b/src/NextItemRules/strategies/pointwise.jl similarity index 100% rename from src/next_item_rules/strategies/pointwise.jl rename to src/NextItemRules/strategies/pointwise.jl diff --git a/src/next_item_rules/strategies/random.jl b/src/NextItemRules/strategies/random.jl similarity index 100% rename from src/next_item_rules/strategies/random.jl rename to src/NextItemRules/strategies/random.jl diff --git a/src/next_item_rules/strategies/randomesque.jl b/src/NextItemRules/strategies/randomesque.jl similarity index 100% rename from src/next_item_rules/strategies/randomesque.jl rename to src/NextItemRules/strategies/randomesque.jl diff --git a/src/next_item_rules/strategies/sequential.jl b/src/NextItemRules/strategies/sequential.jl similarity index 100% rename from src/next_item_rules/strategies/sequential.jl rename to src/NextItemRules/strategies/sequential.jl diff --git a/src/sim/Sim.jl b/src/Sim/Sim.jl similarity index 100% rename from src/sim/Sim.jl rename to src/Sim/Sim.jl diff --git a/src/sim/loop.jl b/src/Sim/loop.jl similarity index 100% rename from src/sim/loop.jl rename to src/Sim/loop.jl diff --git a/src/sim/recorder.jl b/src/Sim/recorder.jl similarity index 100% rename from src/sim/recorder.jl rename to src/Sim/recorder.jl diff --git a/src/sim/run.jl b/src/Sim/run.jl similarity index 100% rename from src/sim/run.jl rename to src/Sim/run.jl From 921df967649135f53348cedf2fd7101a2942590b Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 5 Jul 2025 01:53:59 +0300 Subject: [PATCH 40/42] Remove PushVectors --- src/ComputerAdaptiveTesting.jl | 3 - src/DecisionTree/DecisionTree.jl | 13 ++-- src/vendor/PushVectors.jl | 102 ------------------------------- 3 files changed, 8 insertions(+), 110 deletions(-) delete mode 100644 src/vendor/PushVectors.jl diff --git a/src/ComputerAdaptiveTesting.jl b/src/ComputerAdaptiveTesting.jl index ec81fb5..f1228e0 100644 --- a/src/ComputerAdaptiveTesting.jl +++ b/src/ComputerAdaptiveTesting.jl @@ -13,9 +13,6 @@ export Stateful, Comparison # Extension modules public require_testext -# Vendored dependencies -include("./vendor/PushVectors.jl") - # Config base include("./ConfigBase.jl") diff --git a/src/DecisionTree/DecisionTree.jl b/src/DecisionTree/DecisionTree.jl index b42ad58..97e354c 100644 --- a/src/DecisionTree/DecisionTree.jl +++ b/src/DecisionTree/DecisionTree.jl @@ -3,7 +3,6 @@ module DecisionTree using Mmap: mmap using ComputerAdaptiveTesting.ConfigBase: CatConfigBase -using ComputerAdaptiveTesting.PushVectors using ComputerAdaptiveTesting.NextItemRules using ComputerAdaptiveTesting.Aggregators using ComputerAdaptiveTesting.Responses: BareResponses, Response, add_response!, pop_response! @@ -18,15 +17,19 @@ end Base.@kwdef mutable struct TreePosition max_depth::UInt cur_depth::UInt - todo::PushVector{AgendaItem, Vector{AgendaItem}} + todo::Vector{AgendaItem} parent_ability::Float64 end function TreePosition(max_depth) - TreePosition(max_depth = max_depth, + todo = Vector{AgendaItem}() + sizehint!(todo, max_depth) + TreePosition(; + max_depth, cur_depth = 0, - todo = PushVector{AgendaItem}(max_depth), - parent_ability = 0.0) + todo, + parent_ability = 0.0 + ) end function next!(state::TreePosition, responses, item_bank, question, ability) diff --git a/src/vendor/PushVectors.jl b/src/vendor/PushVectors.jl deleted file mode 100644 index 9949a43..0000000 --- a/src/vendor/PushVectors.jl +++ /dev/null @@ -1,102 +0,0 @@ -module PushVectors - -export PushVector, finish! - -mutable struct PushVector{T, V <: AbstractVector{T}} <: AbstractVector{T} - "Vector used for storage." - parent::V - "Number of elements held by `parent`." - len::Int -end - -""" - PushVector{T}(sizehint = 4) - -Create a `PushVector` for elements typed `T`, with no initial elements. `sizehint` -determines the initial allocated size. -""" -function PushVector{T}(sizehint::Integer = 4) where {T} - sizehint ≥ 0 || throw(DomainError(sizehint, "Invalid initial size.")) - PushVector(Vector{T}(undef, sizehint), 0) -end - -@inline Base.length(v::PushVector) = v.len - -@inline Base.size(v::PushVector) = (v.len,) - -function Base.sizehint!(v::PushVector, n) - if length(v.parent) < n || n ≥ v.len - resize!(v.parent, n) - end - nothing -end - -@inline function Base.getindex(v::PushVector, i) - @boundscheck checkbounds(v, i) - @inbounds v.parent[i] -end - -@inline function Base.setindex!(v::PushVector, x, i) - @boundscheck checkbounds(v, i) - @inbounds v.parent[i] = x -end - -function Base.push!(v::PushVector, x) - v.len += 1 - if v.len > length(v.parent) - resize!(v.parent, v.len * 2) - end - v.parent[v.len] = x - v -end - -function Base.pop!(v::PushVector) - isempty(v) && throw(ArgumentError("vector must be non-empty")) - x = v.parent[v.len] - v.len -= 1 - x -end - -function Base.resize!(v::PushVector, n) - if n < v.len - v.len = n - elseif n > v.len - if n > length(v.parent) - resize!(v.parent, n) - end - v.len = n - end - v -end - -Base.empty!(v::PushVector) = (v.len = 0; v) - -function Base.append!(v::PushVector, xs) - ι_xs = eachindex(xs) # allow generalized indexing - l = length(ι_xs) - if l ≤ 4 - for x in xs - push!(v, x) - end - else - L = l + v.len - if length(v.parent) < L - resize!(v.parent, nextpow(2, nextpow(2, L))) - end - @inbounds copyto!(v.parent, v.len + 1, xs, first(ι_xs), l) - v.len += l - end - v -end - -""" - finish!(v) - -Shrink the buffer `v` to its current content and return that vector. - -!!! NOTE - Consequences are undefined if you modify `v` after this. -""" -finish!(v::PushVector) = resize!(v.parent, v.len) - -end # module From 1382bded4836825976cbc48452800b4867abafa0 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 5 Jul 2025 01:55:35 +0300 Subject: [PATCH 41/42] Version bump => 0.4.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3c91918..4e24eff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComputerAdaptiveTesting" uuid = "5a0d4f34-1f62-4a66-80fe-87aba0485488" authors = ["Frankie Robertson"] -version = "0.3.2" +version = "0.4.0" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" From 3237b057216411928723f1c913a2ecc689c18a53 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 5 Jul 2025 02:08:30 +0300 Subject: [PATCH 42/42] Fixup benchmark script --- benchmark/benchmarks.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index e10f2f9..5466a4e 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -8,7 +8,6 @@ using FittedItemBanks.DummyData: dummy_full, SimpleItemBankSpec, StdModel4PL using ComputerAdaptiveTesting.Aggregators using PsychometricsBazaarBase.Optimizers using PsychometricsBazaarBase.Integrators: even_grid -using ComputerAdaptiveTesting.NextItemRules: mirtcat_quadpts using ComputerAdaptiveTesting.NextItemRules: ExpectationBasedItemCriterion, PointResponseExpectation using ComputerAdaptiveTesting.NextItemRules @@ -27,7 +26,7 @@ function prepare_4pls(group) num_questions = 20, num_testees = 1 ) - integrator = even_grid(-6.0, 6.0, mirtcat_quadpts(1)) + integrator = even_grid(-6.0, 6.0, 121) optimizer = AbilityOptimizer(OneDimOptimOptimizer(-6.0, 6.0, NelderMead())) dist_ability_estimator = PosteriorAbilityEstimator()