Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ FittedItemBanks = "3f797b09-34e4-41d7-acf6-3302ae3248a5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
Expand Down Expand Up @@ -48,13 +47,12 @@ FittedItemBanks = "^0.6.3"
ForwardDiff = "0.10.24"
HypothesisTests = "^0.10.12, ^0.11.0"
Interpolations = "^0.14, ^0.15"
KernelAbstractions = "^0.9.22"
Lazy = "0.15"
LogarithmicNumbers = "1"
MacroTools = "^0.5.6"
Measurements = "^2.10.0"
OrderedCollections = "^1.6"
PsychometricsBazaarBase = "^0.8.0"
PsychometricsBazaarBase = "^0.8.1"
Reexport = "1"
Setfield = "^1"
StaticArrays = "1"
Expand Down
116 changes: 87 additions & 29 deletions src/Comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Comparison
# Should be kept in mind and kept distinct or code reuse

using StatsBase
using FittedItemBanks: AbstractItemBank, ResponseType
using FittedItemBanks: AbstractItemBank, ResponseType, subset
using ..Responses
using ..CatConfig: CatLoopConfig, CatRules
using ..Aggregators: TrackedResponses, add_response!, Speculator, Aggregators, track!,
Expand All @@ -14,11 +14,11 @@ using Base: Iterators

using HypothesisTests
using EffectSizes
using DataFrames
using DataFrames: DataFrame
using ComputerAdaptiveTesting: Stateful

export run_random_comparison, run_comparison
export CatComparisonExecutionStrategy#, IncreaseItemBankSizeExecutionStrategy
export CatComparisonExecutionStrategy, IncreaseItemBankSizeExecutionStrategy
#export FollowOneExecutionStrategy, RunIndependentlyExecutionStrategy
#export DecisionTreeExecutionStrategy
export ReplayResponsesExecutionStrategy
Expand Down Expand Up @@ -83,7 +83,8 @@ end

abstract type CatComparisonExecutionStrategy end

Base.@kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrategy}
struct CatComparisonConfig{
StrategyT <: CatComparisonExecutionStrategy, PhasesT <: NamedTuple}
"""
A named tuple with the (named) CatRules (or compatable) to be compared
"""
Expand All @@ -99,13 +100,42 @@ Base.@kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrate
measurements::Vector{}
=#
"""
Which phases to run and/or call the callback on
The phases to run, optionally paired with a callback
"""
phases::Set{Symbol} = Set((:before_next_item, :after_next_item))
"""
The callback which should take a named tuple with information at different phases
"""
callback::Any
phases::PhasesT
end

"""
CatComparisonConfig(;
rules::NamedTuple{Symbol, StatefulCat},
strategy::CatComparisonExecutionStrategy,
phases::Union{NamedTuple{Symbol, Callable}, Tuple{Symbol}},
callback::Callable
) -> CatComparisonConfig

CatComparisonConfig sets up a evaluation-oriented comparison between different CAT systems.

Specify the comparison by listing: CAT systems in `rules`, a `NamedTuple` which gives
identifiers to implementations of the `StatefulCat` interface; the `strategy` to use,
an implementation of `CatComparisonExecutionStrategy`; the `phases` to run listed as
either as a `NamedTuple` with names of phases and corresponding callbacks or `nothing` a
`Tuple` of phases to run; and a `callback` which will be used as a fallback in cases where
no callback is provided.

The exact phases depend on the strategy used. See their individual documentation for more.
"""
function CatComparisonConfig(; rules, strategy, phases = nothing, callback = nothing)
if callback === nothing
callback = (info; kwargs...) -> nothing
end
if phases === nothing
phases = (:before_next_item, :after_next_item)
end
# TODO: normalize phases into named tuple
if !(phases isa NamedTuple)
phases = NamedTuple((phase => callback for phase in phases))
end
CatComparisonConfig(rules, strategy, phases)
end

# Comparison scenarios:
Expand All @@ -129,9 +159,11 @@ end

#phase_func=nothing;
function measure_all(comparison, system, cat, phase; kwargs...)
if !(phase in comparison.phases)
@info "measure_all" phase comparison.phases
if !(phase in keys(comparison.phases))
return
end
callback = comparison.phases[phase]
strategy = comparison.strategy
#=measurement_results = []
for measurement in comparison.measurements
Expand All @@ -145,7 +177,7 @@ function measure_all(comparison, system, cat, phase; kwargs...)
#end
push!(measurement_results, result)
end=#
comparison.callback((;
callback((;
phase,
system,
cat,
Expand All @@ -158,30 +190,56 @@ struct IncreaseItemBankSizeExecutionStrategy <: CatComparisonExecutionStrategy
item_bank::AbstractItemBank
sizes::AbstractVector{Int}
starting_responses::Int
shuffle::Bool
time_limit::Float64

function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, args...)
if any((size > length(item_bank) for size in sizes))
error("IncreaseItemBankSizeExecutionStrategy: No subset size can be greater than the number of items available in the item bank")
end
new(item_bank, sizes, args...)
end
end

function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes)
return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0)
return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0, false, Inf)
end

function run_comparison(strategy::IncreaseItemBankSizeExecutionStrategy, config)
function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExecutionStrategy})
strategy = comparison.strategy
current_cats = collect(pairs(comparison.rules))
next_current_cats = copy(current_cats)
@info "sizes" strategy.sizes
for size in strategy.sizes
subsetted_item_bank = subset(strategy.item_bank, size)
responses = TrackedResponses(
BareResponses(ResponseType(strategy.item_bank)),
subsetted_item_bank,
config.ability_tracker
)
for _ in 1:(strategy.starting_responses)
next_item = config.next_item(responses, subsetted_item_bank)
add_response!(responses,
Response(ResponseType(subsetted_item_bank), next_item, rand(Bool)))
subsetted_item_bank = subset(strategy.item_bank, 1:size)
empty!(next_current_cats)
for (name, cat) in current_cats
Stateful.set_item_bank!(cat, subsetted_item_bank)
for _ in 1:(strategy.starting_responses)
Stateful.next_item(cat)
end
measure_all(
comparison,
name,
cat,
:before_next_item
)
timed_next_item = @timed Stateful.next_item(cat)
next_item = timed_next_item.value
measure_all(
comparison,
name,
cat,
:after_next_item,
next_item = next_item,
timing = timed_next_item
)
@info "next_item" timed_next_item.time strategy.time_limit
if timed_next_item.time < strategy.time_limit
push!(next_current_cats, name => cat)
end
end
measure_all(config, :before_next_item, before_next_item; responses = responses)
timed_next_item = @timed config.next_item(responses, item_bank)
next_item = timed_next_item.value
measure_all(config, :after_next_item, after_next_item;
responses = responses, next_item = next_item)
current_cats, next_current_cats = next_current_cats, current_cats
end
end

Expand Down
18 changes: 12 additions & 6 deletions src/Stateful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
struct StatefulCatConfig{ItemBankT <: AbstractItemBank} <: StatefulCat
rules::CatRules
tracked_responses::TrackedResponses
item_bank::ItemBankT
item_bank::Ref{ItemBankT}
end

function StatefulCatConfig(rules, item_bank)
Expand All @@ -69,26 +69,27 @@ function StatefulCatConfig(rules, item_bank)
item_bank,
rules.ability_tracker
)
return StatefulCatConfig(rules, tracked_responses, item_bank)
return StatefulCatConfig(rules, tracked_responses, Ref(item_bank))
end

function next_item(config::StatefulCatConfig)
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank)
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank[])
end

function ranked_items(config::StatefulCatConfig)
return sortperm(compute_criteria(
config.rules.next_item, config.tracked_responses, config.item_bank))
config.rules.next_item, config.tracked_responses, config.item_bank[]))
end

function item_criteria(config::StatefulCatConfig)
return compute_criteria(
config.rules.next_item, config.tracked_responses, config.item_bank)
config.rules.next_item, config.tracked_responses, config.item_bank[])
end

function add_response!(config::StatefulCatConfig, index, response)
Aggregators.add_response!(
config.tracked_responses, Response(ResponseType(config.item_bank), index, response))
config.tracked_responses, Response(
ResponseType(config.item_bank[]), index, response))
end

function rollback!(config::StatefulCatConfig)
Expand All @@ -99,6 +100,11 @@ function reset!(config::StatefulCatConfig)
empty!(config.tracked_responses)
end

function set_item_bank!(config::StatefulCatConfig, item_bank)
reset!(config)
config.item_bank[] = item_bank
end

function get_responses(config::StatefulCatConfig)
return config.tracked_responses.responses
end
Expand Down
3 changes: 1 addition & 2 deletions src/TerminationConditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ function (condition::SimpleFunctionTerminationCondition)(responses::TrackedRespo
end

struct RunForeverTerminationCondition <: TerminationCondition end
function (condition::RunForeverTerminationCondition)(responses::TrackedResponses,
items::AbstractItemBank)
function (condition::RunForeverTerminationCondition)(::TrackedResponses, ::AbstractItemBank)
return false
end

Expand Down
12 changes: 8 additions & 4 deletions src/aggregators/Aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import PsychometricsBazaarBase.IntegralCoeffs

export AbilityEstimator, TrackedResponses
export AbilityTracker, NullAbilityTracker, PointAbilityTracker, GriddedAbilityTracker
export ClosedFormNormalAbilityTracker, MultiAbilityTracker, track!
export ClosedFormNormalAbilityTracker, track!
export response_expectation,
add_response!, pop_response!, expectation, distribution_estimator
export PointAbilityEstimator, PriorAbilityEstimator, LikelihoodAbilityEstimator
Expand Down Expand Up @@ -91,12 +91,16 @@ function AbilityTracker(bits...; integrator = nothing, ability_estimator = nothi
end
end

function compatible_tracker(bits...; integrator, ability_estimator, prefer_tracked)
ability_tracker = AbilityTracker(bits...; ability_estimator = ability_estimator)
if ability_tracker isa GriddedAbilityTracker &&
function find_ability_tracker(ability_tracker, typ, integrator)
if ability_tracker isa typ &&
ability_tracker.integrator === integrator
return ability_tracker
end
end

function compatible_tracker(bits...; integrator, ability_estimator, prefer_tracked)
ability_tracker = AbilityTracker(bits...; ability_estimator = ability_estimator)
@returnsome find_ability_tracker(ability_tracker, GriddedAbilityTracker, integrator)
if prefer_tracked
return AbilityTracker(bits...;
integrator = integrator,
Expand Down
1 change: 0 additions & 1 deletion src/aggregators/ability_tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ include("./ability_trackers/grid.jl")
include("./ability_trackers/point.jl")
include("./ability_trackers/closed_form_normal.jl")
include("./ability_trackers/laplace.jl")
include("./ability_trackers/multi.jl")

"""
This method returns a tracked point estimate if it is has the given ability
Expand Down
8 changes: 6 additions & 2 deletions src/aggregators/ability_trackers/grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ end

function GriddedAbilityTracker(ability_estimator::DistributionAbilityEstimator,
integrator::FixedGridIntegrator)
GriddedAbilityTracker(ability_estimator, integrator, fill(NaN, length(integrator.grid)))
GriddedAbilityTracker(ability_estimator, integrator, fill(1.0, length(integrator.grid)))
end

find_grid(integrator::FixedGridIntegrator) = integrator.grid
find_grid(integrator::PreallocatedFixedGridIntegrator) = integrator.inner.grid

function track!(responses, ability_tracker::GriddedAbilityTracker)
ability_pdf = pdf(ability_tracker.ability_estimator, responses)
ability_tracker.cur_ability .= ability_pdf.(ability_tracker.integrator.grid)
grid = find_grid(ability_tracker.integrator)
ability_tracker.cur_ability .= ability_pdf.(grid)
end
9 changes: 0 additions & 9 deletions src/aggregators/ability_trackers/multi.jl

This file was deleted.

7 changes: 4 additions & 3 deletions src/next_item_rules/NextItemRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer
export AbilityCovarianceStateMultiCriterion, StateMultiCriterion, ItemMultiCriterion
export InformationMatrixCriteria
export ScalarizedStateCriteron, ScalarizedItemCriteron
export DRuleItemCriterion, TRuleItemCriterion

# Prelude
include("./prelude/abstract.jl")
Expand All @@ -61,18 +62,18 @@ include("./strategies/exhaustive.jl")
# Combinators
include("./combinators/expectation.jl")
include("./combinators/scalarizers.jl")
include("./combinators/likelihood.jl")

# Criteria
include("./criteria/item/information_special.jl")
include("./criteria/item/information_support.jl")
include("./criteria/item/information.jl")
include("./criteria/item/urry.jl")
include("./criteria/state/ability_variance.jl")
include("./criteria/pointwise/kl.jl")

# Porcelain
include("./porcelain/porcelain.jl")
include("./porcelain/aliases.jl")

# Experimental
include("./experimental/ka.jl")

end
19 changes: 19 additions & 0 deletions src/next_item_rules/combinators/likelihood.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
struct LikelihoodWeightedItemCriterion{
PointwiseItemCriterionT <: PointwiseItemCriterion,
AbilityIntegratorT <: AbilityIntegrator,
AbilityEstimatorT <: DistributionAbilityEstimator
} <: ItemCriterion
criterion::PointwiseItemCriterionT
integrator::AbilityIntegratorT
estimator::AbilityEstimatorT
end

function compute_criterion(
lwic::LikelihoodWeightedItemCriterion,
tracked_responses::TrackedResponses,
item_idx
)
func = FunctionProduct(
pdf(lwic.estimator, tracked_responses), lwic.criterion(tracked_responses, item_idx))
lwic.integrator(func, 0, lwic.estimator, tracked_responses)
end
Loading
Loading