From b1baf3709b57678b8c3afdad060acf7d3e74475a Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Wed, 25 Dec 2024 12:49:21 +0200 Subject: [PATCH 1/8] Remove KernelAbstractions implementation --- Project.toml | 2 - src/next_item_rules/NextItemRules.jl | 3 - src/next_item_rules/experimental/ka.jl | 206 ------------------------- 3 files changed, 211 deletions(-) delete mode 100644 src/next_item_rules/experimental/ka.jl diff --git a/Project.toml b/Project.toml index ccfb65a..62fb79f 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,6 @@ FittedItemBanks = "3f797b09-34e4-41d7-acf6-3302ae3248a5" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899" @@ -48,7 +47,6 @@ FittedItemBanks = "^0.6.3" ForwardDiff = "0.10.24" HypothesisTests = "^0.10.12, ^0.11.0" Interpolations = "^0.14, ^0.15" -KernelAbstractions = "^0.9.22" Lazy = "0.15" LogarithmicNumbers = "1" MacroTools = "^0.5.6" diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index acf69f2..fa76091 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -72,7 +72,4 @@ include("./criteria/state/ability_variance.jl") # Porcelain include("./porcelain/aliases.jl") -# Experimental -include("./experimental/ka.jl") - end diff --git a/src/next_item_rules/experimental/ka.jl b/src/next_item_rules/experimental/ka.jl deleted file mode 100644 index e2a2a4e..0000000 --- a/src/next_item_rules/experimental/ka.jl +++ /dev/null @@ -1,206 +0,0 @@ -import ComputerAdaptiveTesting.NextItemRules -using ComputerAdaptiveTesting.ItemBanks -using FittedItemBanks: item_bank_xs -using KernelAbstractions -using KernelAbstractions.Extras.LoopInfo: @unroll - -""" -$(TYPEDEF) -$(TYPEDFIELDS) -""" -struct KernelAbstractionsExhaustiveSearchConfig{ArgsT} <: NextItemStrategy - kernel_args::ArgsT - - function KernelAbstractionsExhaustiveSearchConfig(kwargs...) - return new{typeof(kwargs)}(kwargs) - end -end - -""" -$(TYPEDEF) -$(TYPEDFIELDS) -""" -struct KernelAbstractionsExhaustiveSearch{KernelT} <: NextItemStrategy - kernel::KernelT -end - -function NextItemRules.preallocate(config::KernelAbstractionsExhaustiveSearchConfig) - return KernelAbstractionsExhaustiveSearch(gridded_point_expected_posterior_variance_kernel_simple(config.kernel_args...)) -end - -function expected_response(ability, xs, ys) - start = xs.start - stop = xs.stop - lendiv = xs.lendiv - ability_index = (ability - start) * (lendiv / (stop - start)) - if isnan(ability_index) - KernelAbstractions.@print("ability ", - ability, "\tstart ", start, "\tstop ", stop, "\tlendiv ", - lendiv, "\tability_index", ability_index, "\n") - end - index_floor = floor(Int, ability_index) - index_ceil = index_floor + 1 - if index_floor < 1 - return ys[1] - elseif index_ceil > length(ys) - return ys[end] - else - # Linear interpolation - return ys[index_floor] + - (ys[index_ceil] - ys[index_floor]) * (ability_index - index_floor) - end -end - -function var(x, w) - #KernelAbstractions.@print("x ", x, "\tw ", w, "\n") - mean = zero(eltype(x)) - norm = zero(eltype(x)) - #@unroll - for i in eachindex(x) - #KernelAbstractions.@print("i ", i, "\tw[i] ", w[i], "\tx[i]: ", x[i], "\n") - mean += w[i] * x[i] - norm += w[i] - end - if norm == zero(eltype(x)) - return typemax(eltype(x)) - end - mean /= norm - var = zero(eltype(x)) - #@unroll - for i in eachindex(x) - dev = (x[i] - mean) - var += w[i] * dev * dev - end - return var / norm -end - -# This kernel is parallel across items -# It is not tiled -@kernel inbounds=true function gridded_point_expected_posterior_variance_kernel_simple( - # The xs from the item bank as a LinRange - @Const(in_gridded_item_bank_xs), - # All ys from the item bank - @Const(in_gridded_item_bank_ys), - # The evaluated likelihood at the integration points - @Const(in_likelihood_points), - # The current point estimate of the ability - @Const(ability_estimate), - # The resulting expected posterior variance array - out_epv -) - lh_eltype = eltype(in_likelihood_points) - grid_size = length(in_gridded_item_bank_xs) - lh_buf = @private lh_eltype grid_size - #@private item_index - item_index = @index(Global) - item_ys = @view in_gridded_item_bank_ys[:, item_index] - # Step 1: Compute the expected response for the current item - #KernelAbstractions.@print("ability_estimate", ability_estimate, "\tin_gridded_item_bank_xs ", in_gridded_item_bank_xs, "item_ys", item_ys, "\n") - response_expectation = expected_response( - ability_estimate, in_gridded_item_bank_xs, item_ys) - # Step 2: Get the variance in the positive response case - lh_buf .= in_likelihood_points .* item_ys - #KernelAbstractions.@print("First lh buf") - #KernelAbstractions.@print("in_gridded_item_bank_xs ", size(in_gridded_item_bank_xs), "\tlh_buf ", size(lh_buf), "\n") - #KernelAbstractions.@print("\n\n\n") - KernelAbstractions.@print("in_gridded_item_bank_xs ", - in_gridded_item_bank_xs, "\tlh_buf ", lh_buf, "\n") - pos_exp_var = @inline var(in_gridded_item_bank_xs, lh_buf) - # Step 3: Get the variance in the negative response case - lh_buf .= in_likelihood_points .* (one(eltype(item_ys)) .- item_ys) - #KernelAbstractions.@print("Second lh buf") - #KernelAbstractions.@print("in_gridded_item_bank_xs ", size(in_gridded_item_bank_xs), "\tlh_buf ", size(lh_buf), "\n") - #KernelAbstractions.@print("\n\n\n") - neg_exp_var = @inline var(in_gridded_item_bank_xs, lh_buf) - # Step 4: Combine the variances using the response expectation - #KernelAbstractions.@print("item_index ", item_index, "\tresponse_expectation ", response_expectation, "\tneg_exp_var ", neg_exp_var, "\tpos_exp_var ", pos_exp_var, "\n") - if isinf(pos_exp_var) || isinf(neg_exp_var) - out_epv[item_index] = typemax(eltype(response_expectation)) - else - negative_response_expectation = one(eltype(response_expectation)) - - response_expectation - out_epv[item_index] = negative_response_expectation * neg_exp_var + - response_expectation * pos_exp_var - end -end - -function ( - rule_config::RuleConfigT where { - RuleConfigT <: ItemStrategyNextItemRule{ - <:KernelAbstractionsExhaustiveSearchConfig, - <:ExpectationBasedItemCriterion{ - <:PointResponseExpectation, <:AbilityVarianceStateCriterion} -} -} -)(responses, items::DichotomousPointsWithLogsItemBank) - return preallocate(rule_config)(responses, items) -end - -function move(backend, input) - # TODO replace with adapt(backend, input) - #out = KernelAbstractions.allocate(backend, eltype(input), size(input)) - out = KernelAbstractions.allocate(backend, Float32, size(input)) - return KernelAbstractions.copyto!(backend, out, input) -end - -function linrange_to_float32(input) - return LinRange(Float32(input.start), Float32(input.stop), input.len) -end - -function ( - rule::RuleT where { - RuleT <: ItemStrategyNextItemRule{ - <:KernelAbstractionsExhaustiveSearch, - <:ExpectationBasedItemCriterion{ - <:PointResponseExpectation, <:AbilityVarianceStateCriterion} -} -} -)(tracked_responses, items::DichotomousPointsWithLogsItemBank{}) - backend = rule.strategy.kernel.backend - responses = tracked_responses.responses - #=exp_resp = Aggregators.response_expectation( - rule, - tracked_responses, - item_idx - )=# - - @info "responses" responses.indices responses.values - #for item_index in responses.indices - #@info "ys" items.inner_bank.ys[:, item_index] - #end - ability_estimate = rule.criterion.ability_estimator(tracked_responses) - @info "ability_estimate" rule.criterion.ability_estimator ability_estimate - in_gridded_item_bank_xs = item_bank_xs(items) - in_gridded_item_bank_ys = items.inner_bank.ys - in_gridded_item_bank_log_ys = items.log_ys - (num_quadrature_points, num_items) = size(in_gridded_item_bank_ys) - # TODO: This could be handled by TrackedResponses - log_likelihood_points = reduce(.+, - ( - @view in_gridded_item_bank_log_ys[Int(resp_value) + 1, :, resp_idx] - for (resp_idx, resp_value) - in zip(responses.indices, responses.values) - ); - init = zeros(eltype(in_gridded_item_bank_log_ys), num_quadrature_points) - ) - c = maximum(log_likelihood_points) - log_likelihood_points .-= c - # TODO: Keep this as logs - log_likelihood_points = exp.(log_likelihood_points) - @info "xs" in_gridded_item_bank_xs - in_gridded_item_bank_xs = linrange_to_float32(in_gridded_item_bank_xs) - in_gridded_item_bank_ys = move(backend, in_gridded_item_bank_ys) - log_likelihood_points = move(backend, log_likelihood_points) - out_epv = KernelAbstractions.zeros(backend, eltype(in_gridded_item_bank_ys), num_items) - rule.strategy.kernel( - in_gridded_item_bank_xs, - in_gridded_item_bank_ys, - log_likelihood_points, - ability_estimate, - out_epv - ) - synchronize(backend) - out_epv[responses.indices] .= typemax(eltype(out_epv)) - @info "out_epv" out_epv - return argmin(out_epv) -end From 25df36498bc7adf7d7578fbc499ef16b279f5297 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Wed, 25 Dec 2024 12:52:02 +0200 Subject: [PATCH 2/8] Fix typo in criteria --- src/next_item_rules/prelude/criteria.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/next_item_rules/prelude/criteria.jl b/src/next_item_rules/prelude/criteria.jl index 118a76e..f95f247 100644 --- a/src/next_item_rules/prelude/criteria.jl +++ b/src/next_item_rules/prelude/criteria.jl @@ -46,7 +46,7 @@ function compute_criteria( items::AbstractItemBank ) where {ItemCriterionT <: ItemCriterion, TrackedResponseT <: TrackedResponses} objective_state = init_thread(criterion, responses) - return [criterion(objective_state, responses, item_idx) + return [compute_criterion(criterion, objective_state, responses, item_idx) for item_idx in eachindex(items)] end From ea8b4e0635f8e0b2be7452e051c60025b95ede94 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Thu, 9 Jan 2025 09:39:07 +0200 Subject: [PATCH 3/8] Remove MultiAbilityTracker --- src/aggregators/Aggregators.jl | 2 +- src/aggregators/ability_tracker.jl | 1 - src/aggregators/ability_trackers/multi.jl | 9 --------- 3 files changed, 1 insertion(+), 11 deletions(-) delete mode 100644 src/aggregators/ability_trackers/multi.jl diff --git a/src/aggregators/Aggregators.jl b/src/aggregators/Aggregators.jl index 2c374fb..663af7b 100644 --- a/src/aggregators/Aggregators.jl +++ b/src/aggregators/Aggregators.jl @@ -26,7 +26,7 @@ import PsychometricsBazaarBase.IntegralCoeffs export AbilityEstimator, TrackedResponses export AbilityTracker, NullAbilityTracker, PointAbilityTracker, GriddedAbilityTracker -export ClosedFormNormalAbilityTracker, MultiAbilityTracker, track! +export ClosedFormNormalAbilityTracker, track! export response_expectation, add_response!, pop_response!, expectation, distribution_estimator export PointAbilityEstimator, PriorAbilityEstimator, LikelihoodAbilityEstimator diff --git a/src/aggregators/ability_tracker.jl b/src/aggregators/ability_tracker.jl index 05c678f..6605bbf 100644 --- a/src/aggregators/ability_tracker.jl +++ b/src/aggregators/ability_tracker.jl @@ -80,7 +80,6 @@ include("./ability_trackers/grid.jl") include("./ability_trackers/point.jl") include("./ability_trackers/closed_form_normal.jl") include("./ability_trackers/laplace.jl") -include("./ability_trackers/multi.jl") """ This method returns a tracked point estimate if it is has the given ability diff --git a/src/aggregators/ability_trackers/multi.jl b/src/aggregators/ability_trackers/multi.jl deleted file mode 100644 index d2407de..0000000 --- a/src/aggregators/ability_trackers/multi.jl +++ /dev/null @@ -1,9 +0,0 @@ -mutable struct MultiAbilityTracker <: AbilityTracker - trackers::Vector{AbilityTracker} -end - -function track!(responses, ability_tracker::MultiAbilityTracker) - for tracker in ability_tracker.trackers - track!(responses, tracker) - end -end From 3eb97a81c07ec101a27cf2faea124cb739222ee4 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Thu, 9 Jan 2025 10:08:15 +0200 Subject: [PATCH 4/8] Refactor compatible_tracker a bit --- src/aggregators/Aggregators.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/aggregators/Aggregators.jl b/src/aggregators/Aggregators.jl index 663af7b..7313fb5 100644 --- a/src/aggregators/Aggregators.jl +++ b/src/aggregators/Aggregators.jl @@ -91,12 +91,16 @@ function AbilityTracker(bits...; integrator = nothing, ability_estimator = nothi end end -function compatible_tracker(bits...; integrator, ability_estimator, prefer_tracked) - ability_tracker = AbilityTracker(bits...; ability_estimator = ability_estimator) - if ability_tracker isa GriddedAbilityTracker && +function find_ability_tracker(ability_tracker, typ, integrator) + if ability_tracker isa typ && ability_tracker.integrator === integrator return ability_tracker end +end + +function compatible_tracker(bits...; integrator, ability_estimator, prefer_tracked) + ability_tracker = AbilityTracker(bits...; ability_estimator = ability_estimator) + @returnsome find_ability_tracker(ability_tracker, GriddedAbilityTracker, integrator) if prefer_tracked return AbilityTracker(bits...; integrator = integrator, From a2a24091c40e83da91840b03d26edd893ad346ca Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Mon, 10 Feb 2025 18:04:12 +0200 Subject: [PATCH 5/8] Fix up GriddedAbilityTracker --- src/aggregators/ability_trackers/grid.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/aggregators/ability_trackers/grid.jl b/src/aggregators/ability_trackers/grid.jl index cd166fb..2b996db 100644 --- a/src/aggregators/ability_trackers/grid.jl +++ b/src/aggregators/ability_trackers/grid.jl @@ -9,10 +9,14 @@ end function GriddedAbilityTracker(ability_estimator::DistributionAbilityEstimator, integrator::FixedGridIntegrator) - GriddedAbilityTracker(ability_estimator, integrator, fill(NaN, length(integrator.grid))) + GriddedAbilityTracker(ability_estimator, integrator, fill(1.0, length(integrator.grid))) end +find_grid(integrator::FixedGridIntegrator) = integrator.grid +find_grid(integrator::PreallocatedFixedGridIntegrator) = integrator.inner.grid + function track!(responses, ability_tracker::GriddedAbilityTracker) ability_pdf = pdf(ability_tracker.ability_estimator, responses) - ability_tracker.cur_ability .= ability_pdf.(ability_tracker.integrator.grid) + grid = find_grid(ability_tracker.integrator) + ability_tracker.cur_ability .= ability_pdf.(grid) end From 1998aa09aa22316acd67972affb61359f6214002 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Mon, 24 Mar 2025 14:50:40 +0200 Subject: [PATCH 6/8] Improve Comparison including IncreaseItemBankSizeExecutionStrategy --- src/Comparison.jl | 116 ++++++++++++++++++++++++++++++++++------------ src/Stateful.jl | 18 ++++--- 2 files changed, 99 insertions(+), 35 deletions(-) diff --git a/src/Comparison.jl b/src/Comparison.jl index d8716cf..f8cec75 100644 --- a/src/Comparison.jl +++ b/src/Comparison.jl @@ -4,7 +4,7 @@ module Comparison # Should be kept in mind and kept distinct or code reuse using StatsBase -using FittedItemBanks: AbstractItemBank, ResponseType +using FittedItemBanks: AbstractItemBank, ResponseType, subset using ..Responses using ..CatConfig: CatLoopConfig, CatRules using ..Aggregators: TrackedResponses, add_response!, Speculator, Aggregators, track!, @@ -14,11 +14,11 @@ using Base: Iterators using HypothesisTests using EffectSizes -using DataFrames +using DataFrames: DataFrame using ComputerAdaptiveTesting: Stateful export run_random_comparison, run_comparison -export CatComparisonExecutionStrategy#, IncreaseItemBankSizeExecutionStrategy +export CatComparisonExecutionStrategy, IncreaseItemBankSizeExecutionStrategy #export FollowOneExecutionStrategy, RunIndependentlyExecutionStrategy #export DecisionTreeExecutionStrategy export ReplayResponsesExecutionStrategy @@ -83,7 +83,8 @@ end abstract type CatComparisonExecutionStrategy end -Base.@kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrategy} +struct CatComparisonConfig{ + StrategyT <: CatComparisonExecutionStrategy, PhasesT <: NamedTuple} """ A named tuple with the (named) CatRules (or compatable) to be compared """ @@ -99,13 +100,42 @@ Base.@kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrate measurements::Vector{} =# """ - Which phases to run and/or call the callback on + The phases to run, optionally paired with a callback """ - phases::Set{Symbol} = Set((:before_next_item, :after_next_item)) - """ - The callback which should take a named tuple with information at different phases - """ - callback::Any + phases::PhasesT +end + +""" + CatComparisonConfig(; + rules::NamedTuple{Symbol, StatefulCat}, + strategy::CatComparisonExecutionStrategy, + phases::Union{NamedTuple{Symbol, Callable}, Tuple{Symbol}}, + callback::Callable + ) -> CatComparisonConfig + +CatComparisonConfig sets up a evaluation-oriented comparison between different CAT systems. + +Specify the comparison by listing: CAT systems in `rules`, a `NamedTuple` which gives +identifiers to implementations of the `StatefulCat` interface; the `strategy` to use, +an implementation of `CatComparisonExecutionStrategy`; the `phases` to run listed as +either as a `NamedTuple` with names of phases and corresponding callbacks or `nothing` a +`Tuple` of phases to run; and a `callback` which will be used as a fallback in cases where +no callback is provided. + +The exact phases depend on the strategy used. See their individual documentation for more. +""" +function CatComparisonConfig(; rules, strategy, phases = nothing, callback = nothing) + if callback === nothing + callback = (info; kwargs...) -> nothing + end + if phases === nothing + phases = (:before_next_item, :after_next_item) + end + # TODO: normalize phases into named tuple + if !(phases isa NamedTuple) + phases = NamedTuple((phase => callback for phase in phases)) + end + CatComparisonConfig(rules, strategy, phases) end # Comparison scenarios: @@ -129,9 +159,11 @@ end #phase_func=nothing; function measure_all(comparison, system, cat, phase; kwargs...) - if !(phase in comparison.phases) + @info "measure_all" phase comparison.phases + if !(phase in keys(comparison.phases)) return end + callback = comparison.phases[phase] strategy = comparison.strategy #=measurement_results = [] for measurement in comparison.measurements @@ -145,7 +177,7 @@ function measure_all(comparison, system, cat, phase; kwargs...) #end push!(measurement_results, result) end=# - comparison.callback((; + callback((; phase, system, cat, @@ -158,30 +190,56 @@ struct IncreaseItemBankSizeExecutionStrategy <: CatComparisonExecutionStrategy item_bank::AbstractItemBank sizes::AbstractVector{Int} starting_responses::Int + shuffle::Bool + time_limit::Float64 + + function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, args...) + if any((size > length(item_bank) for size in sizes)) + error("IncreaseItemBankSizeExecutionStrategy: No subset size can be greater than the number of items available in the item bank") + end + new(item_bank, sizes, args...) + end end function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes) - return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0) + return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0, false, Inf) end -function run_comparison(strategy::IncreaseItemBankSizeExecutionStrategy, config) +function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExecutionStrategy}) + strategy = comparison.strategy + current_cats = collect(pairs(comparison.rules)) + next_current_cats = copy(current_cats) + @info "sizes" strategy.sizes for size in strategy.sizes - subsetted_item_bank = subset(strategy.item_bank, size) - responses = TrackedResponses( - BareResponses(ResponseType(strategy.item_bank)), - subsetted_item_bank, - config.ability_tracker - ) - for _ in 1:(strategy.starting_responses) - next_item = config.next_item(responses, subsetted_item_bank) - add_response!(responses, - Response(ResponseType(subsetted_item_bank), next_item, rand(Bool))) + subsetted_item_bank = subset(strategy.item_bank, 1:size) + empty!(next_current_cats) + for (name, cat) in current_cats + Stateful.set_item_bank!(cat, subsetted_item_bank) + for _ in 1:(strategy.starting_responses) + Stateful.next_item(cat) + end + measure_all( + comparison, + name, + cat, + :before_next_item + ) + timed_next_item = @timed Stateful.next_item(cat) + next_item = timed_next_item.value + measure_all( + comparison, + name, + cat, + :after_next_item, + next_item = next_item, + timing = timed_next_item + ) + @info "next_item" timed_next_item.time strategy.time_limit + if timed_next_item.time < strategy.time_limit + push!(next_current_cats, name => cat) + end end - measure_all(config, :before_next_item, before_next_item; responses = responses) - timed_next_item = @timed config.next_item(responses, item_bank) - next_item = timed_next_item.value - measure_all(config, :after_next_item, after_next_item; - responses = responses, next_item = next_item) + current_cats, next_current_cats = next_current_cats, current_cats end end diff --git a/src/Stateful.jl b/src/Stateful.jl index 1a7563f..f34ac38 100644 --- a/src/Stateful.jl +++ b/src/Stateful.jl @@ -59,7 +59,7 @@ end struct StatefulCatConfig{ItemBankT <: AbstractItemBank} <: StatefulCat rules::CatRules tracked_responses::TrackedResponses - item_bank::ItemBankT + item_bank::Ref{ItemBankT} end function StatefulCatConfig(rules, item_bank) @@ -69,26 +69,27 @@ function StatefulCatConfig(rules, item_bank) item_bank, rules.ability_tracker ) - return StatefulCatConfig(rules, tracked_responses, item_bank) + return StatefulCatConfig(rules, tracked_responses, Ref(item_bank)) end function next_item(config::StatefulCatConfig) - return best_item(config.rules.next_item, config.tracked_responses, config.item_bank) + return best_item(config.rules.next_item, config.tracked_responses, config.item_bank[]) end function ranked_items(config::StatefulCatConfig) return sortperm(compute_criteria( - config.rules.next_item, config.tracked_responses, config.item_bank)) + config.rules.next_item, config.tracked_responses, config.item_bank[])) end function item_criteria(config::StatefulCatConfig) return compute_criteria( - config.rules.next_item, config.tracked_responses, config.item_bank) + config.rules.next_item, config.tracked_responses, config.item_bank[]) end function add_response!(config::StatefulCatConfig, index, response) Aggregators.add_response!( - config.tracked_responses, Response(ResponseType(config.item_bank), index, response)) + config.tracked_responses, Response( + ResponseType(config.item_bank[]), index, response)) end function rollback!(config::StatefulCatConfig) @@ -99,6 +100,11 @@ function reset!(config::StatefulCatConfig) empty!(config.tracked_responses) end +function set_item_bank!(config::StatefulCatConfig, item_bank) + reset!(config) + config.item_bank[] = item_bank +end + function get_responses(config::StatefulCatConfig) return config.tracked_responses.responses end From 42a085b8542d10f72bfa3bc39a0cb877e72ed764 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Tue, 25 Mar 2025 12:32:25 +0200 Subject: [PATCH 7/8] Refactor next item rules including adding pointwise utils and switch tests away from xunit --- Project.toml | 2 +- src/TerminationConditions.jl | 3 +- src/next_item_rules/NextItemRules.jl | 4 + src/next_item_rules/combinators/likelihood.jl | 19 +++ src/next_item_rules/criteria/pointwise/kl.jl | 38 +++++ .../criteria/state/ability_variance.jl | 7 +- src/next_item_rules/porcelain/aliases.jl | 8 +- src/next_item_rules/porcelain/porcelain.jl | 10 ++ test/Project.toml | 1 - test/ability_estimator_1d.jl | 103 +++++++------ test/ability_estimator_2d.jl | 143 +++++++++--------- test/aqua.jl | 24 +-- test/dt.jl | 2 +- test/format.jl | 2 +- test/jet.jl | 2 +- test/runtests.jl | 37 ++++- test/smoke.jl | 4 +- test/stateful.jl | 2 +- test/tests_top.jl | 56 ------- 19 files changed, 261 insertions(+), 206 deletions(-) create mode 100644 src/next_item_rules/combinators/likelihood.jl create mode 100644 src/next_item_rules/criteria/pointwise/kl.jl diff --git a/Project.toml b/Project.toml index 62fb79f..185950a 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ LogarithmicNumbers = "1" MacroTools = "^0.5.6" Measurements = "^2.10.0" OrderedCollections = "^1.6" -PsychometricsBazaarBase = "^0.8.0" +PsychometricsBazaarBase = "^0.8.1" Reexport = "1" Setfield = "^1" StaticArrays = "1" diff --git a/src/TerminationConditions.jl b/src/TerminationConditions.jl index 8b729ac..17ff07c 100644 --- a/src/TerminationConditions.jl +++ b/src/TerminationConditions.jl @@ -41,8 +41,7 @@ function (condition::SimpleFunctionTerminationCondition)(responses::TrackedRespo end struct RunForeverTerminationCondition <: TerminationCondition end -function (condition::RunForeverTerminationCondition)(responses::TrackedResponses, - items::AbstractItemBank) +function (condition::RunForeverTerminationCondition)(::TrackedResponses, ::AbstractItemBank) return false end diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index fa76091..e3b6d47 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -47,6 +47,7 @@ export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer export AbilityCovarianceStateMultiCriterion, StateMultiCriterion, ItemMultiCriterion export InformationMatrixCriteria export ScalarizedStateCriteron, ScalarizedItemCriteron +export DRuleItemCriterion, TRuleItemCriterion # Prelude include("./prelude/abstract.jl") @@ -61,6 +62,7 @@ include("./strategies/exhaustive.jl") # Combinators include("./combinators/expectation.jl") include("./combinators/scalarizers.jl") +include("./combinators/likelihood.jl") # Criteria include("./criteria/item/information_special.jl") @@ -68,8 +70,10 @@ include("./criteria/item/information_support.jl") include("./criteria/item/information.jl") include("./criteria/item/urry.jl") include("./criteria/state/ability_variance.jl") +include("./criteria/pointwise/kl.jl") # Porcelain +include("./porcelain/porcelain.jl") include("./porcelain/aliases.jl") end diff --git a/src/next_item_rules/combinators/likelihood.jl b/src/next_item_rules/combinators/likelihood.jl new file mode 100644 index 0000000..03da6a6 --- /dev/null +++ b/src/next_item_rules/combinators/likelihood.jl @@ -0,0 +1,19 @@ +struct LikelihoodWeightedItemCriterion{ + PointwiseItemCriterionT <: PointwiseItemCriterion, + AbilityIntegratorT <: AbilityIntegrator, + AbilityEstimatorT <: DistributionAbilityEstimator +} <: ItemCriterion + criterion::PointwiseItemCriterionT + integrator::AbilityIntegratorT + estimator::AbilityEstimatorT +end + +function compute_criterion( + lwic::LikelihoodWeightedItemCriterion, + tracked_responses::TrackedResponses, + item_idx +) + func = FunctionProduct( + pdf(lwic.estimator, tracked_responses), lwic.criterion(tracked_responses, item_idx)) + lwic.integrator(func, 0, lwic.estimator, tracked_responses) +end diff --git a/src/next_item_rules/criteria/pointwise/kl.jl b/src/next_item_rules/criteria/pointwise/kl.jl new file mode 100644 index 0000000..630680c --- /dev/null +++ b/src/next_item_rules/criteria/pointwise/kl.jl @@ -0,0 +1,38 @@ +function kl(item_response::ItemResponse, r0, theta) + r = resp_vec(item_response, theta) + resp = 0.0 + for (p0, p) in zip(r0, r) + resp += p0 * (log(p0) - log(p)) + end + return resp +end + +struct PosteriorExpectedKLInformationItemCriterion{ + PointEstimatorT <: PointAbilityEstimator, + DistributionEstimatorT <: DistributionAbilityEstimator, + IntegratorT <: AbilityIntegrator +} <: PointwiseItemCriterion +end + +function PosteriorExpectedKLInformationItemCriterion(bits...) + @requiresome point_estimator = PointAbilityEstimator(bits...) + @requiresome distribution_estimator = DistributionAbilityEstimator(bits...) + @requiresome integrator = AbilityIntegrator(bits...) + PosteriorExpectedKLInformationItemCriterion( + point_estimator, distribution_estimator, integrator) +end + +function compute_pointwise_criterion( + item_criterion::PosteriorExpectedKLInformationItemCriterion, + tracked_responses::TrackedResponses, + item_idx) + theta_0 = maybe_tracked_ability_estimate(tracked_responses, + item_criterion.point_estimator) + item_response = ItemResponse(tracked_responses.item_bank, item_idx) + r0 = resp_vec(item_response, theta_0) + expectation( + theta -> kl(item_response, r0, theta), + item_criterion.integrator, + item_criterion.distribution_estimator, + tracked_responses) +end diff --git a/src/next_item_rules/criteria/state/ability_variance.jl b/src/next_item_rules/criteria/state/ability_variance.jl index 5e13a7a..343a873 100644 --- a/src/next_item_rules/criteria/state/ability_variance.jl +++ b/src/next_item_rules/criteria/state/ability_variance.jl @@ -18,7 +18,6 @@ function _get_dist_est_and_integrator(bits...) # XXX: Weakness in this initialisation system is showing now # This needs ot be explicitly passed dist_est and integrator, but this may # be burried within a MeanAbilityEstimator - @returnsome find1_instance(AbilityVarianceStateCriterion, bits) dist_est = DistributionAbilityEstimator(bits...) integrator = AbilityIntegrator(bits...) if dist_est !== nothing && integrator !== nothing @@ -27,12 +26,14 @@ function _get_dist_est_and_integrator(bits...) # So let's just handle this case individually for now # (Is this going to cause a problem with this being picked over something more appropriate?) @requiresome mean_ability_est = MeanAbilityEstimator(bits...) - return (dist_est, integrator) + return (mean_ability_est.dist_est, mean_ability_est.integrator) end function AbilityVarianceStateCriterion(bits...) skip_zero = false - @requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...) + @returnsome find1_instance(AbilityVarianceStateCriterion, bits) + @requiresome dist_est_integrator_pair = _get_dist_est_and_integrator(bits...) + (dist_est, integrator) = dist_est_integrator_pair return AbilityVarianceStateCriterion(dist_est, integrator, skip_zero) end diff --git a/src/next_item_rules/porcelain/aliases.jl b/src/next_item_rules/porcelain/aliases.jl index b67ef63..392b9a5 100644 --- a/src/next_item_rules/porcelain/aliases.jl +++ b/src/next_item_rules/porcelain/aliases.jl @@ -41,12 +41,8 @@ const mirtcat_next_item_aliases = Dict( "MEPV" => _mirtcat_helper((bits, ability_estimator) -> ExpectationBasedItemCriterion( ability_estimator, AbilityVarianceStateCriterion(bits...))), - "Drule" => _mirtcat_helper((bits, ability_estimator) -> ScalarizedItemCriteron( - InformationMatrixCriteria(ability_estimator), - DeterminantScalarizer())), - "Trule" => _mirtcat_helper((bits, ability_estimator) -> ScalarizedItemCriteron( - InformationMatrixCriteria(ability_estimator), - TraceScalarizer())) + "Drule" => _mirtcat_helper((bits, ability_estimator) -> DRuleItemCriteron(ability_estimator)), + "Trule" => _mirtcat_helper((bits, ability_estimator) -> TRuleItemCriteron(ability_estimator)) ) # 'MLWI' for maximum likelihood weighted information diff --git a/src/next_item_rules/porcelain/porcelain.jl b/src/next_item_rules/porcelain/porcelain.jl index 8b13789..dd43165 100644 --- a/src/next_item_rules/porcelain/porcelain.jl +++ b/src/next_item_rules/porcelain/porcelain.jl @@ -1 +1,11 @@ +function DRuleItemCriterion(ability_estimator) + ScalarizedItemCriteron( + InformationMatrixCriteria(ability_estimator), + DeterminantScalarizer()) +end +function TRuleItemCriterion(ability_estimator) + ScalarizedItemCriteron( + InformationMatrixCriteria(ability_estimator), + TraceScalarizer()) +end diff --git a/test/Project.toml b/test/Project.toml index 61ce277..2b60c28 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,7 +11,6 @@ PsychometricsBazaarBase = "b0d9cada-d963-45e9-a4c6-4746243987f1" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -XUnit = "3e3c03f2-1a94-11e9-2981-050a4ca824ab" [compat] Aqua = "0.5.5, 0.6.5" diff --git a/test/ability_estimator_1d.jl b/test/ability_estimator_1d.jl index 2ae2dde..7c80037 100644 --- a/test/ability_estimator_1d.jl +++ b/test/ability_estimator_1d.jl @@ -38,64 +38,67 @@ map_1d = ModeAbilityEstimator(pa_est_1d, optimizer_1d) mle_mean_1d = MeanAbilityEstimator(lh_est_1d, integrator_1d) mle_mode_1d = ModeAbilityEstimator(lh_est_1d, optimizer_1d) -@testcase "Estimator: single dim MAP" begin - @test map_1d(tracked_responses_1d)≈1.0 atol=0.001 -end +@testset "abilest_1d" begin + @testset "Estimator: single dim MAP" begin + @test map_1d(tracked_responses_1d)≈1.0 atol=0.001 + end -@testcase "Estimator: single dim EAP" begin - @test eap_1d(tracked_responses_1d)≈1.0 atol=0.001 -end + @testset "Estimator: single dim EAP" begin + @test eap_1d(tracked_responses_1d)≈1.0 atol=0.001 + end -@testcase "Estimator: single mle mean" begin - @test mle_mean_1d(tracked_responses_1d)≈1.0 atol=0.001 -end + @testset "Estimator: single mle mean" begin + @test mle_mean_1d(tracked_responses_1d)≈1.0 atol=0.001 + end -@testcase "Estimator: single mle mode" begin - @test mle_mode_1d(tracked_responses_1d)≈1.0 atol=0.001 -end + @testset "Estimator: single mle mode" begin + @test mle_mode_1d(tracked_responses_1d)≈1.0 atol=0.001 + end -information_item_criterion = InformationItemCriterion(mle_mean_1d) + information_item_criterion = InformationItemCriterion(mle_mean_1d) -@testcase "1 dim neg information smaller closer to current estimate" begin - @test ( - compute_criterion(information_item_criterion, tracked_responses_1d, 5) < - compute_criterion(information_item_criterion, tracked_responses_1d, 6) - ) -end + @testset "1 dim neg information smaller closer to current estimate" begin + @test ( + compute_criterion(information_item_criterion, tracked_responses_1d, 5) < + compute_criterion(information_item_criterion, tracked_responses_1d, 6) + ) + end -@testcase "1 dim neg information smaller with igher discrimination" begin - @test ( - compute_criterion(information_item_criterion, tracked_responses_1d, 7) < - compute_criterion(information_item_criterion, tracked_responses_1d, 5) < - compute_criterion(information_item_criterion, tracked_responses_1d, 8) - ) -end - -ability_variance_state_criterion = AbilityVarianceStateCriterion(lh_est_1d, integrator_1d) -ability_variance_item_criterion = ExpectationBasedItemCriterion( - mle_mean_1d, - ability_variance_state_criterion -) + @testset "1 dim neg information smaller with igher discrimination" begin + @test ( + compute_criterion(information_item_criterion, tracked_responses_1d, 7) < + compute_criterion(information_item_criterion, tracked_responses_1d, 5) < + compute_criterion(information_item_criterion, tracked_responses_1d, 8) + ) + end -@testcase "postposterior 1 dim variance smaller closer to current estimate" begin - @test ( - compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 5) < - compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 6) + ability_variance_state_criterion = AbilityVarianceStateCriterion( + lh_est_1d, integrator_1d) + ability_variance_item_criterion = ExpectationBasedItemCriterion( + mle_mean_1d, + ability_variance_state_criterion ) -end -@testcase "postposterior 1 dim variance smaller with higher discrimination" begin - @test ( - compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 7) < - compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 5) < - compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 8) - ) -end + @testset "postposterior 1 dim variance smaller closer to current estimate" begin + @test ( + compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 5) < + compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 6) + ) + end + + @testset "postposterior 1 dim variance smaller with higher discrimination" begin + @test ( + compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 7) < + compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 5) < + compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 8) + ) + end -@testcase "1 dim variance decreases with new responses" begin - orig_var = compute_criterion(ability_variance_state_criterion, tracked_responses_1d) - next_responses = deepcopy(tracked_responses_1d) - add_response!(next_responses, Response(ResponseType(item_bank_1d), 5, 0)) - new_var = compute_criterion(ability_variance_state_criterion, next_responses) - @test new_var < orig_var + @testset "1 dim variance decreases with new responses" begin + orig_var = compute_criterion(ability_variance_state_criterion, tracked_responses_1d) + next_responses = deepcopy(tracked_responses_1d) + add_response!(next_responses, Response(ResponseType(item_bank_1d), 5, 0)) + new_var = compute_criterion(ability_variance_state_criterion, next_responses) + @test new_var < orig_var + end end diff --git a/test/ability_estimator_2d.jl b/test/ability_estimator_2d.jl index ab2f465..6929bea 100644 --- a/test/ability_estimator_2d.jl +++ b/test/ability_estimator_2d.jl @@ -40,72 +40,79 @@ map_2d = ModeAbilityEstimator(pa_est_2d, optimizer_2d) mle_mean_2d = MeanAbilityEstimator(lh_est_2d, integrator_2d) mle_mode_2d = ModeAbilityEstimator(lh_est_2d, optimizer_2d) -@testcase "Estimator: 2 dim MAP" begin - @test map_2d(tracked_responses_2d)≈[1.0, 1.0] atol=0.001 -end - -@testcase "Estimator: 2 dim EAP" begin - @test eap_2d(tracked_responses_2d)≈[1.0, 1.0] atol=0.001 -end - -@testcase "Estimator: 2 mle mean" begin - ans = mle_mean_2d(tracked_responses_2d) - @test ans[1] - ans[2]≈0.0 atol=0.001 -end - -@testcase "Estimator: 2 mle mode" begin - ans = mle_mode_2d(tracked_responses_2d) - @test ans[1] + ans[2]≈2.0 atol=0.001 -end - -@testcase "2 dim information higher closer to current estimate" begin - information_matrix_criteria = InformationMatrixCriteria(mle_mean_2d) - information_criterion = ScalarizedItemCriteron( - information_matrix_criteria, DeterminantScalarizer()) - - # Item closer to the current estimate (1, 1) - close_item = 5 - # Item further from the current estimate - far_item = 6 - - close_info = compute_criterion(information_criterion, tracked_responses_2d, close_item) - far_info = compute_criterion(information_criterion, tracked_responses_2d, far_item) - - @test close_info > far_info -end - -@testcase "2 dim variance smaller closer to current estimate" begin - covariance_state_criterion = AbilityCovarianceStateMultiCriterion( - lh_est_2d, integrator_2d) - variance_criterion = ScalarizedStateCriteron( - covariance_state_criterion, DeterminantScalarizer()) - variance_item_criterion = ExpectationBasedItemCriterion(mle_mean_2d, variance_criterion) - - # Item closer to the current estimate (1, 1) - close_item = 5 - # Item further from the current estimate - far_item = 6 - - close_var = compute_criterion(variance_item_criterion, tracked_responses_2d, close_item) - far_var = compute_criterion(variance_item_criterion, tracked_responses_2d, far_item) - - @test close_var < far_var -end - -@testcase "2 dim variance is whack with trace scalarizer" begin - covariance_state_criterion = AbilityCovarianceStateMultiCriterion( - lh_est_2d, integrator_2d) - variance_criterion = ScalarizedStateCriteron( - covariance_state_criterion, TraceScalarizer()) - variance_item_criterion = ExpectationBasedItemCriterion(mle_mean_2d, variance_criterion) - - # Item closer to the current estimate (1, 1) - close_item = 5 - # Item further from the current estimate - far_item = 6 - - close_var = compute_criterion(variance_item_criterion, tracked_responses_2d, close_item) - far_var = compute_criterion(variance_item_criterion, tracked_responses_2d, far_item) - - @test far_var < close_var +@testset "abilest_2d" begin + @testset "Estimator: 2 dim MAP" begin + @test map_2d(tracked_responses_2d)≈[1.0, 1.0] atol=0.001 + end + + @testset "Estimator: 2 dim EAP" begin + @test eap_2d(tracked_responses_2d)≈[1.0, 1.0] atol=0.001 + end + + @testset "Estimator: 2 mle mean" begin + ans = mle_mean_2d(tracked_responses_2d) + @test ans[1] - ans[2]≈0.0 atol=0.001 + end + + @testset "Estimator: 2 mle mode" begin + ans = mle_mode_2d(tracked_responses_2d) + @test ans[1] + ans[2]≈2.0 atol=0.001 + end + + @testset "2 dim information higher closer to current estimate" begin + information_matrix_criteria = InformationMatrixCriteria(mle_mean_2d) + information_criterion = ScalarizedItemCriteron( + information_matrix_criteria, DeterminantScalarizer()) + + # Item closer to the current estimate (1, 1) + close_item = 5 + # Item further from the current estimate + far_item = 6 + + close_info = compute_criterion( + information_criterion, tracked_responses_2d, close_item) + far_info = compute_criterion(information_criterion, tracked_responses_2d, far_item) + + @test close_info > far_info + end + + @testset "2 dim variance smaller closer to current estimate" begin + covariance_state_criterion = AbilityCovarianceStateMultiCriterion( + lh_est_2d, integrator_2d) + variance_criterion = ScalarizedStateCriteron( + covariance_state_criterion, DeterminantScalarizer()) + variance_item_criterion = ExpectationBasedItemCriterion( + mle_mean_2d, variance_criterion) + + # Item closer to the current estimate (1, 1) + close_item = 5 + # Item further from the current estimate + far_item = 6 + + close_var = compute_criterion( + variance_item_criterion, tracked_responses_2d, close_item) + far_var = compute_criterion(variance_item_criterion, tracked_responses_2d, far_item) + + @test close_var < far_var + end + + @testset "2 dim variance is whack with trace scalarizer" begin + covariance_state_criterion = AbilityCovarianceStateMultiCriterion( + lh_est_2d, integrator_2d) + variance_criterion = ScalarizedStateCriteron( + covariance_state_criterion, TraceScalarizer()) + variance_item_criterion = ExpectationBasedItemCriterion( + mle_mean_2d, variance_criterion) + + # Item closer to the current estimate (1, 1) + close_item = 5 + # Item further from the current estimate + far_item = 6 + + close_var = compute_criterion( + variance_item_criterion, tracked_responses_2d, close_item) + far_var = compute_criterion(variance_item_criterion, tracked_responses_2d, far_item) + + @test far_var < close_var + end end diff --git a/test/aqua.jl b/test/aqua.jl index 229f3f4..4dc76f1 100644 --- a/test/aqua.jl +++ b/test/aqua.jl @@ -1,15 +1,17 @@ using Aqua using ComputerAdaptiveTesting -Aqua.test_all( - ComputerAdaptiveTesting; - ambiguities = false -) -# Ambiguities are not tested in default configuration as a workaround for -# https://github.com/JuliaTesting/Aqua.jl/issues/77 -# Core is not included because of Core.Number, namely +@testset "aqua" begin + Aqua.test_all( + ComputerAdaptiveTesting; + ambiguities = false + ) + # Ambiguities are not tested in default configuration as a workaround for + # https://github.com/JuliaTesting/Aqua.jl/issues/77 + # Core is not included because of Core.Number, namely -# ComputerAdaptiveTesting gets errors from FowardDiff extending Core.Number -# Could possibly get some of these fixed in ForwardDiff eventually? -# https://github.com/JuliaDiff/ForwardDiff.jl/issues/597 -Aqua.test_ambiguities([ComputerAdaptiveTesting]) + # ComputerAdaptiveTesting gets errors from FowardDiff extending Core.Number + # Could possibly get some of these fixed in ForwardDiff eventually? + # https://github.com/JuliaDiff/ForwardDiff.jl/issues/597 + Aqua.test_ambiguities([ComputerAdaptiveTesting]) +end diff --git a/test/dt.jl b/test/dt.jl index 0372378..8c89b07 100644 --- a/test/dt.jl +++ b/test/dt.jl @@ -8,7 +8,7 @@ integrator = FunctionIntegrator(Integrators.even_grid(-6, 6, 61)) ability_estimator = MeanAbilityEstimator(LikelihoodAbilityEstimator(), integrator) get_response = auto_responder(@view true_responses[:, 1]) -@testcase "round trip" begin +@testset "decision tree round trip" begin next_item_rule = ItemStrategyNextItemRule( AbilityVarianceStateCriterion( distribution_estimator(ability_estimator), integrator), diff --git a/test/format.jl b/test/format.jl index efd4127..9a73ca2 100644 --- a/test/format.jl +++ b/test/format.jl @@ -1,7 +1,7 @@ using JuliaFormatter using ComputerAdaptiveTesting -@testcase "format" begin +@testset "format" begin dir = pkgdir(ComputerAdaptiveTesting) @test format(dir * "/src"; overwrite = false) @test format(dir * "/test"; overwrite = false) diff --git a/test/jet.jl b/test/jet.jl index 61f13be..aa3541f 100644 --- a/test/jet.jl +++ b/test/jet.jl @@ -1,7 +1,7 @@ using JET using Optim: Optim -@testset "JET checks" begin +@testset "jet" begin rep = report_package( ComputerAdaptiveTesting; target_modules = ( diff --git a/test/runtests.jl b/test/runtests.jl index 0c8ca81..5a60559 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,36 @@ -using XUnit +using Base.Filesystem: mktempdir +using ComputerAdaptiveTesting +using ComputerAdaptiveTesting.Aggregators +using FittedItemBanks.DummyData: dummy_full, SimpleItemBankSpec, StdModel3PL, + VectorContinuousDomain, BooleanResponse, std_normal +using FittedItemBanks +using ComputerAdaptiveTesting.CatConfig +using ComputerAdaptiveTesting.Responses +using ComputerAdaptiveTesting.NextItemRules +using ComputerAdaptiveTesting.TerminationConditions +using ComputerAdaptiveTesting.Sim +using PsychometricsBazaarBase.Integrators +using PsychometricsBazaarBase.Optimizers +using ComputerAdaptiveTesting.DecisionTree +using ComputerAdaptiveTesting: Stateful +using Distributions +using Distributions: ZeroMeanIsoNormal, Zeros, ScalMat +using Optim +using Random +using ResumableFunctions -runtests("tests_top.jl", ARGS...) +using Test + +include("./dummy.jl") +using .Dummy + +@testset "test" begin + include("./aqua.jl") + include("./jet.jl") + include("./ability_estimator_1d.jl") + include("./ability_estimator_2d.jl") + include("./smoke.jl") + include("./dt.jl") + include("./stateful.jl") + include("./format.jl") +end \ No newline at end of file diff --git a/test/smoke.jl b/test/smoke.jl index b0a9404..4a4c176 100644 --- a/test/smoke.jl +++ b/test/smoke.jl @@ -1,6 +1,6 @@ #(item_bank, abilities, responses) = dummy_full(Random.default_rng(42), SimpleItemBankSpec(StdModel4PL(), VectorContinuousDomain(), BooleanResponse()), 2; num_questions=100, num_testees=3) -@testcase "Smoke test 1d" begin +@testset "Smoke test 1d" begin (item_bank, abilities, true_responses) = dummy_full( Random.default_rng(42), SimpleItemBankSpec(StdModel3PL(), OneDimContinuousDomain(), BooleanResponse()); @@ -44,7 +44,7 @@ end #= -@testcase "Smoke test 2d" begin +@testset "Smoke test 2d" begin Random.seed!(42) (item_bank, abilities, responses) = dummy_mirt_4pl(2; num_questions=4, num_testees=2) end diff --git a/test/stateful.jl b/test/stateful.jl index 4e90f1d..f01a470 100644 --- a/test/stateful.jl +++ b/test/stateful.jl @@ -1,4 +1,4 @@ -@testcase "Stateful" begin +@testset "Stateful" begin rng = Random.default_rng(42) # Create test data diff --git a/test/tests_top.jl b/test/tests_top.jl index 81e2511..8b13789 100644 --- a/test/tests_top.jl +++ b/test/tests_top.jl @@ -1,57 +1 @@ -using Base.Filesystem: mktempdir -using ComputerAdaptiveTesting -using ComputerAdaptiveTesting.Aggregators -using FittedItemBanks.DummyData: dummy_full, SimpleItemBankSpec, StdModel3PL, - VectorContinuousDomain, BooleanResponse, std_normal -using FittedItemBanks -using ComputerAdaptiveTesting.CatConfig -using ComputerAdaptiveTesting.Responses -using ComputerAdaptiveTesting.NextItemRules -using ComputerAdaptiveTesting.TerminationConditions -using ComputerAdaptiveTesting.Sim -using PsychometricsBazaarBase.Integrators -using PsychometricsBazaarBase.Optimizers -using ComputerAdaptiveTesting.DecisionTree -using ComputerAdaptiveTesting: Stateful -using Distributions -using Distributions: ZeroMeanIsoNormal, Zeros, ScalMat -using Optim -using Random -using ResumableFunctions -using XUnit - -include("./dummy.jl") -using .Dummy - -@testset "aqua" begin - include("./aqua.jl") -end - -@testset "jet" begin - include("./jet.jl") -end - -@testset "abilest_1d" begin - include("./ability_estimator_1d.jl") -end - -@testset "abilest_2d" begin - include("./ability_estimator_2d.jl") -end - -@testset "smoke" begin - include("./smoke.jl") -end - -@testset "dt" begin - include("./dt.jl") -end - -@testset "stateful" begin - include("./stateful.jl") -end - -@testset "format" begin - include("./format.jl") -end From 7a0ada64b3998eaf65b80aa4828f14ceb8c8c815 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Tue, 25 Mar 2025 13:43:04 +0200 Subject: [PATCH 8/8] Fix up format problem --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 5a60559..18bc266 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,4 +33,4 @@ using .Dummy include("./dt.jl") include("./stateful.jl") include("./format.jl") -end \ No newline at end of file +end