Skip to content

Commit 6dd9ce9

Browse files
author
Frankie Robertson
committed
Add GreedyForcedContentBalancer and PointwiseNextItemRule
1 parent 3bed638 commit 6dd9ce9

File tree

7 files changed

+147
-16
lines changed

7 files changed

+147
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Distributions = "^0.25.88"
4848
DocStringExtensions = " ^0.9"
4949
EffectSizes = "^1.0.1"
5050
FillArrays = "0.13, 1.5.0"
51-
FittedItemBanks = "^0.6.3, ^0.7.0"
51+
FittedItemBanks = "^0.7.2"
5252
ForwardDiff = "1"
5353
HypothesisTests = "^0.10.12, ^0.11.0"
5454
Interpolations = "^0.14, ^0.15"

src/next_item_rules/NextItemRules.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome,
2121
find1_instance, find1_type
2222
using PsychometricsBazaarBase.Integrators: Integrator, intval
2323
using PsychometricsBazaarBase: Integrators
24+
using PsychometricsBazaarBase.IndentWrappers: indent
2425
import PsychometricsBazaarBase.IntegralCoeffs
2526
using FittedItemBanks: AbstractItemBank, DiscreteDomain, DomainType,
2627
ItemResponse, OneDimContinuousDomain, domdims, item_params,
27-
resp, resp_vec, responses
28+
resp, resp_vec, responses, subset_view
2829
using ..Aggregators
2930
using ..Aggregators: covariance_matrix, FunctionProduct
3031

@@ -34,6 +35,7 @@ using Base.Order
3435
using StaticArrays: SVector
3536
using ConstructionBase: constructorof
3637
import ForwardDiff
38+
import Base: show
3739

3840
export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread
3941
export NextItemRule, ItemStrategyNextItemRule
@@ -46,7 +48,7 @@ export EmpiricalInformationPointwiseItemCategoryCriterion
4648
export TotalItemInformation
4749
export RandomNextItemRule
4850
export PiecewiseNextItemRule, MemoryNextItemRule, FixedFirstItemNextItemRule
49-
export ExhaustiveSearch
51+
export ExhaustiveSearch, RandomesqueStrategy
5052
export preallocate
5153
export compute_criteria, compute_criterion, compute_multi_criterion
5254
export best_item
@@ -68,6 +70,8 @@ include("./strategies/random.jl")
6870
include("./strategies/randomesque.jl")
6971
include("./strategies/sequential.jl")
7072
include("./strategies/exhaustive.jl")
73+
include("./strategies/pointwise.jl")
74+
include("./strategies/balance.jl")
7175

7276
# Combinators
7377
include("./combinators/expectation.jl")

src/next_item_rules/prelude/next_item_rule.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ function best_item(rule::NextItemRule, tracked_responses::TrackedResponses)
5353
best_item(rule, tracked_responses, tracked_responses.item_bank)
5454
end
5555

56+
function Base.show(io::IO, ::MIME"text/plain", next_item_rule::ItemStrategyNextItemRule)
57+
println(io, "Strategy:")
58+
show(indent_io, MIME"text/plain"(), rules.strategy)
59+
println(io, "Item criterion:")
60+
show(indent_io, MIME"text/plain"(), rules.criterion)
61+
end
62+
5663
# Default implementation
5764
function compute_criteria(::NextItemRule, ::TrackedResponses)
5865
nothing
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""
2+
$(TYPEDEF)
3+
$(TYPEDFIELDS)
4+
5+
This content balancing procedure takes target proportions for each group of items.
6+
At each step the group with the lowest ratio of seen items to target is selected.
7+
8+
http://dx.doi.org/10.1207/s15324818ame0403_4
9+
"""
10+
struct GreedyForcedContentBalancer{InnerRuleT <: NextItemRule} <: NextItemRule
11+
targets::Vector{Float64}
12+
groups::Vector{Int}
13+
inner_rule::InnerRuleT
14+
end
15+
16+
function GreedyForcedContentBalancer(targets::Dict, groups, bits...)
17+
targets_vec = zeros(Float64, length(targets))
18+
groups_idxs = zeros(Int, length(groups))
19+
group_lookup = Dict{Any, Int}()
20+
for (idx, group) in enumerate(groups)
21+
if haskey(group_lookup, group)
22+
group_idx = group_lookup[group]
23+
else
24+
group_idx = length(group_lookup) + 1
25+
group_lookup[group] = group_idx
26+
end
27+
groups_idxs[idx] = group_idx
28+
end
29+
if length(group_lookup) != length(targets)
30+
error("Number of groups $(length(group_lookup)) does not match number of targets $(length(targets))")
31+
end
32+
for (group, group_idx) in pairs(group_lookup)
33+
targets_vec[group_idx] = get(targets, group, 0.0)
34+
end
35+
GreedyForcedContentBalancer(targets_vec, groups_idxs, bits...)
36+
end
37+
38+
function GreedyForcedContentBalancer(targets::AbstractVector, groups, bits...)
39+
GreedyForcedContentBalancer(targets, groups, NextItemRule(bits...))
40+
end
41+
42+
function show(io::IO, ::MIME"text/plain", rule::GreedyForcedContentBalancer)
43+
indent_io = indent(io, 2)
44+
println(io, "Greedy + forced content balancer")
45+
println(indent_io, "Target ratio: " * join(rule.targets, ", "))
46+
print(indent_io, "Using rule: ")
47+
show(indent_io, MIME("text/plain"), rule.inner_rule)
48+
end
49+
50+
function next_item_bank(targets, groups, responses, items)
51+
seen = zeros(UInt, size(targets))
52+
indices = responses.responses.indices
53+
for group_idx in groups[indices]
54+
seen[group_idx] += 1
55+
end
56+
next_group_idx = argmin(seen ./ targets)
57+
matching_indicator = groups .== next_group_idx
58+
next_items = subset_view(items, matching_indicator)
59+
return (next_items, matching_indicator)
60+
end
61+
62+
function best_item(
63+
rule::GreedyForcedContentBalancer,
64+
responses::TrackedResponses,
65+
items
66+
)
67+
next_items, matching_indicator = next_item_bank(rule.targets, rule.groups, responses, items)
68+
inner_idx = best_item(rule.inner_rule, responses, next_items)
69+
for (outer_idx, in_group) in enumerate(matching_indicator)
70+
if in_group
71+
inner_idx -= 1
72+
if inner_idx <= 0
73+
return outer_idx
74+
end
75+
end
76+
end
77+
error("No item found in group length $(length(next_items)) with inner index $inner_idx")
78+
end
79+
80+
function compute_criteria(
81+
rule::GreedyForcedContentBalancer,
82+
responses::TrackedResponses,
83+
items
84+
)
85+
next_items, matching_indicator = next_item_bank(rule.targets, rule.groups, responses, items)
86+
criteria = compute_criteria(rule.inner_rule, responses, next_items)
87+
expanded = fill(Inf, length(items))
88+
expanded[matching_indicator] .= criteria
89+
return expanded
90+
end

src/next_item_rules/strategies/exhaustive.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
1-
function exhaustive_search(objective::ItemCriterionT,
2-
responses::TrackedResponseT,
3-
items::AbstractItemBank)::Tuple{
4-
Int,
5-
Float64
6-
} where {ItemCriterionT <: ItemCriterion, TrackedResponseT <: TrackedResponses}
7-
#pre_next_item(expectation_tracker, items)
8-
objective_state = init_thread(objective, responses)
1+
function exhaustive_search(
2+
callback,
3+
answered_items::AbstractVector{Int},
4+
items::AbstractItemBank
5+
)::Tuple{Int, Float64}
96
min_obj_idx::Int = -1
107
min_obj_val::Float64 = Inf
118
for item_idx in eachindex(items)
129
# TODO: Add these back in
1310
#@init irf_states_storage = zeros(Int, length(responses) + 1)
14-
if (findfirst(idx -> idx == item_idx, responses.responses.indices) !== nothing)
11+
if (findfirst(idx -> idx == item_idx, answered_items) !== nothing)
1512
continue
1613
end
1714

18-
obj_val = compute_criterion(objective, objective_state, responses, item_idx)
15+
obj_val = callback(item_idx)
1916

2017
if obj_val <= min_obj_val
2118
min_obj_val = obj_val
@@ -25,6 +22,18 @@ function exhaustive_search(objective::ItemCriterionT,
2522
return (min_obj_idx, min_obj_val)
2623
end
2724

25+
function exhaustive_search(objective::ItemCriterionT,
26+
responses::TrackedResponseT,
27+
items::AbstractItemBank)::Tuple{
28+
Int,
29+
Float64
30+
} where {ItemCriterionT <: ItemCriterion, TrackedResponseT <: TrackedResponses}
31+
objective_state = init_thread(objective, responses)
32+
return exhaustive_search(responses.responses.indices, items) do item_idx
33+
return compute_criterion(objective, objective_state, responses, item_idx)
34+
end
35+
end
36+
2837
"""
2938
$(TYPEDEF)
3039
$(TYPEDFIELDS)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
struct PointwiseNextItemRule{CriterionT <: PointwiseItemCriterion, PointsT <: AbstractArray{<:Number}} <: NextItemRule
2+
criterion::CriterionT
3+
points::PointsT
4+
end
5+
6+
function best_item(rule::PointwiseNextItemRule, responses::TrackedResponses, items)
7+
num_responses = length(responses.responses.indices)
8+
next_index = num_responses + 1
9+
if next_index > length(rule.points)
10+
error("Number of responses exceeds the number of points defined in the rule.")
11+
end
12+
current_point = rule.points[next_index]
13+
idx, _ = exhaustive_search(responses.responses.indices, items) do item_idx
14+
return compute_criterion(rule.criterion, ItemResponse(items, item_idx), current_point)
15+
end
16+
return idx
17+
end
18+
19+
function PointwiseFirstNextItemRule(criterion, points, rule)
20+
PiecewiseNextItemRule((length(points),), (PointwiseNextItemRule(criterion, points), rule))
21+
end

src/next_item_rules/strategies/sequential.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ This is the most basic rule for choosing the next item in a CAT. It simply
66
picks a random item from the set of items that have not yet been
77
administered.
88
"""
9-
@kwdef struct PiecewiseNextItemRule{BreaksT, RulesT} <: NextItemRule
9+
@kwdef struct PiecewiseNextItemRule{RulesT} <: NextItemRule
1010
# Tuple of Ints
11-
breaks::BreaksT
12-
# Types of NextItemRules
11+
breaks::Tuple{Int}
12+
# Tuple of NextItemRules
1313
rules::RulesT
1414
end
1515

0 commit comments

Comments
 (0)