Skip to content

Commit 832b863

Browse files
author
Frankie Robertson
committed
Add RecordedCatLoop
1 parent 4d98d8d commit 832b863

File tree

5 files changed

+217
-2
lines changed

5 files changed

+217
-2
lines changed

src/Aggregators/Aggregators.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ export FunctionOptimizer, FunctionIntegrator
5252
export DistributionAbilityEstimator
5353
export variance, variance_given_mean, mean_1d
5454
export RiemannEnumerationIntegrator
55+
export get_integrator
5556
# export EnumerationOptimizer
5657

5758
# Basic types
@@ -200,6 +201,10 @@ struct FunctionIntegrator{IntegratorT <: Integrator} <: AbilityIntegrator
200201
integrator::IntegratorT
201202
end
202203

204+
function get_integrator(integrator::FunctionIntegrator)
205+
return integrator.integrator
206+
end
207+
203208
function (integrator::FunctionIntegrator{IntegratorT})(f::F,
204209
ncomp,
205210
lh_function::LHF) where {F, LHF, IntegratorT}

src/Aggregators/tracked.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ struct TrackedLikelihoodIntegrator{IntegratorT <: Integrator} <: AbilityIntegrat
2020
tracker::GriddedAbilityTracker
2121
end
2222

23+
function get_integrator(integrator::TrackedLikelihoodIntegrator)
24+
return integrator.integrator
25+
end
26+
2327
function (integrator::TrackedLikelihoodIntegrator{IntegratorT})(f::F,
2428
ncomp) where {F, IntegratorT}
2529
integrator.integrator(FunctionArgProduct(f), integrator.tracker.cur_ability, ncomp)

src/Sim/Sim.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@ using ElasticArrays
55
using ElasticArrays: sizehint_lastdim!
66
using DocStringExtensions
77
using StatsBase
8-
using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse
8+
using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, domdims
99
using PsychometricsBazaarBase.Integrators
1010
using PsychometricsBazaarBase.IndentWrappers: indent
1111
using ..ConfigBase
1212
using ..Responses
1313
using ..Rules: CatRules
1414
using ..Aggregators: TrackedResponses,
1515
add_response!,
16+
get_integrator,
1617
Aggregators,
1718
AbilityIntegrator,
1819
AbilityEstimator,
@@ -22,15 +23,17 @@ using ..Aggregators: TrackedResponses,
2223
MeanAbilityEstimator,
2324
LikelihoodAbilityEstimator,
2425
RiemannEnumerationIntegrator
25-
using ..NextItemRules: compute_criteria, best_item
26+
using ..NextItemRules: AbilityVariance, compute_criteria, best_item
2627
import Base: show
2728

2829
export CatRecorder, CatRecording
2930
export CatLoop, record!
31+
export RecordedCatLoop
3032
export run_cat, prompt_response, auto_responder
3133

3234
include("./recorder.jl")
3335
include("./loop.jl")
3436
include("./run.jl")
37+
include("./recorded_loop.jl")
3538

3639
end

src/Sim/recorded_loop.jl

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
struct RecordedCatLoop
2+
cat_loop::CatLoop{<: CatRules}
3+
recorder::CatRecorder
4+
item_bank::Union{AbstractItemBank, Nothing}
5+
end
6+
7+
function _prepare_get_response!(kwargs)
8+
has_responses = haskey(kwargs, :responses)
9+
has_get_response = haskey(kwargs, :get_response)
10+
if has_responses && has_get_response
11+
error("Cannot provide both `responses` and `get_response`.")
12+
elseif !has_responses && !has_get_response
13+
error("Must provide either `responses` or `get_response`.")
14+
elseif has_get_response
15+
return nothing, pop!(kwargs, :get_response)
16+
else
17+
responses = pop!(kwargs, :responses)
18+
return responses, Sim.auto_responder(responses)
19+
end
20+
end
21+
22+
function _walk_find_type(obj, typ, out=[])
23+
if obj isa typ
24+
push!(out, obj)
25+
end
26+
for fieldname in propertynames(obj)
27+
_walk_find_type(getfield(obj, fieldname), typ, out)
28+
end
29+
return out
30+
end
31+
32+
function _find_mean_ability(rules)
33+
if rules.ability_estimator isa MeanAbilityEstimator
34+
return rules.ability_estimator
35+
end
36+
result = _walk_find_type(rules.next_item, MeanAbilityEstimator)
37+
if !isempty(result)
38+
return result[1]
39+
end
40+
result = _walk_find_type(rules.termination_condition, MeanAbilityEstimator)
41+
if !isempty(result)
42+
return result[1]
43+
end
44+
return nothing
45+
end
46+
47+
function _find_ability_variance(rules)
48+
result = _walk_find_type(rules.next_item, AbilityVariance)
49+
if !isempty(result)
50+
return result[1]
51+
end
52+
return nothing
53+
end
54+
55+
function enrich_recorder_requests(old_requests, rules)
56+
requests = Dict()
57+
for (k, v) in pairs(old_requests)
58+
new_v = Dict{Symbol, Any}(pairs(v))
59+
type = get(new_v, :type, nothing)
60+
if type in (:ability, :ability_distribution, :ability_stddev)
61+
if haskey(new_v, :estimator) && haskey(new_v, :source)
62+
error("Cannot provide both `estimator` and `source` for request `$k`.")
63+
elseif !haskey(new_v, :estimator)
64+
if !haskey(new_v, :source)
65+
error("Must provide either `estimator` or `source` for request `$k`.")
66+
end
67+
source = new_v[:source]
68+
if source != :any
69+
error("Not implemented yet: `source = $source` for request `$k`.")
70+
end
71+
if type == :ability
72+
new_v[:estimator] = rules.ability_estimator
73+
elseif type == :ability_stddev
74+
error("Not implemented yet: `type = :ability_stddev` for request `$k`.")
75+
elseif type == :ability_distribution
76+
estimator = nothing
77+
integrator = nothing
78+
mean_ability = _find_mean_ability(rules)
79+
if mean_ability === nothing
80+
ability_variance = _find_ability_variance(rules)
81+
if ability_variance === nothing
82+
error("Cannot find a `MeanAbilityEstimator` or `AbilityVariance` in the rules for request `$k`.")
83+
end
84+
estimator = ability_variance.dist_est
85+
integrator = ability_variance.integrator
86+
else
87+
estimator = distribution_estimator(mean_ability)
88+
integrator = mean_ability.integrator
89+
end
90+
new_v[:estimator] = estimator
91+
if !haskey(new_v, :integrator)
92+
new_v[:integrator] = integrator
93+
end
94+
if !haskey(new_v, :points)
95+
integrator = get_integrator(new_v[:integrator])
96+
if !(integrator isa AnyGridIntegrator)
97+
error("Must provide `points` for request `$k` when `integrator` is not an `AnyGridIntegrator`.")
98+
end
99+
new_v[:points] = get_grid(integrator)
100+
end
101+
end
102+
end
103+
end
104+
requests[k] = NamedTuple(new_v)
105+
end
106+
return requests
107+
end
108+
109+
"""
110+
```julia
111+
RecordedCatLoop(;
112+
rules::CatRules,
113+
item_bank::AbstractItemBank = nothing,
114+
responses::Union{Nothing, Vector{ResponseType}} = nothing,
115+
dims::Union{Nothing, Tuple{Int, Int}} = nothing,
116+
expected_responses::Int = 0,
117+
get_response::Function = nothing,
118+
new_response_callback::Function = nothing,
119+
new_response_callbacks::Vector{Function} = Any[]
120+
requests...
121+
)
122+
```
123+
124+
This `RecordedCatLoop` is a simplified construction of a `[CatRules](@ref)`-based `[CatLoop](@ref)` and `[CatRecorder](@ref)`.
125+
126+
It can be constructed with just some cat `rules`, an `item_bank`, and a response memory `responses`, as well as usually one or more `requests` for the `[CatRecorder](@ref).
127+
In this case `dims` are provided by the `item_bank`, and `expected_responses` is set to the length of `responses` as well as used to provide responses using `get_responses`, otherwise the respective arguments must be provided.
128+
The arguments `get_response`, `new_response_callback`, and `new_response_callbacks` are passed to the underlying `CatLoop`.
129+
130+
The resulting `RecordedCatLoop` can be run directly with run_cat.
131+
"""
132+
function RecordedCatLoop(; kwargs...)
133+
kwargs = Dict(kwargs)
134+
responses, get_response = _prepare_get_response!(kwargs)
135+
local expected_responses, rules
136+
if responses !== nothing
137+
expected_responses = length(responses)
138+
else
139+
expected_responses = pop!(kwargs, :expected_responses, 0)
140+
end
141+
if haskey(kwargs, :rules)
142+
rules = pop!(kwargs, :rules)
143+
else
144+
error("Must provide `rules`.")
145+
end
146+
new_response_callback = pop!(kwargs, :new_response_callback, nothing)
147+
new_response_callbacks = pop!(kwargs, :new_response_callbacks, Any[])
148+
local dims
149+
item_bank = nothing
150+
if !haskey(kwargs, :item_bank) && !haskey(kwargs, :dims)
151+
error("Must provide either `item_bank` or `dims`.")
152+
end
153+
if haskey(kwargs, :item_bank)
154+
item_bank = pop!(kwargs, :item_bank)
155+
dims = domdims(item_bank)
156+
end
157+
if haskey(kwargs, :dims)
158+
dims = pop!(kwargs, :dims)
159+
end
160+
requests = enrich_recorder_requests(kwargs, rules)
161+
cat_recorder = CatRecorder(dims, expected_responses; requests...)
162+
RecordedCatLoop(
163+
CatLoop(;
164+
rules,
165+
get_response,
166+
new_response_callback,
167+
new_response_callbacks,
168+
recorder=cat_recorder
169+
),
170+
cat_recorder,
171+
item_bank
172+
)
173+
end
174+
175+
"""
176+
$TYPEDSIGNATURES
177+
178+
Run a given [RecordedCatLoop](@ref) by delegating the call to the wrapped [CatLoop](@ref).
179+
180+
In case `item_bank` is not provided, the item bank provided during the construction of `RecordedCatLoop` is used.
181+
"""
182+
function run_cat(loop::RecordedCatLoop,
183+
item_bank::AbstractItemBank;
184+
ib_labels = nothing)
185+
run_cat(loop.cat_loop, item_bank; ib_labels=ib_labels)
186+
end
187+
188+
function run_cat(loop::RecordedCatLoop; ib_labels = nothing)
189+
if loop.item_bank === nothing
190+
error("Trying to run a RecordedCatLoop without an item bank when no item bank was provided at construction time.")
191+
end
192+
run_cat(loop, loop.item_bank; ib_labels=ib_labels)
193+
end

src/Sim/recorder.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,14 +255,24 @@ function name_to_label(name)
255255
titlecase(join(split(String(name), "_"), " "))
256256
end
257257

258+
function hasallkeys(haystack, needles...)
259+
return all(n in keys(haystack) for n in needles)
260+
end
261+
258262
function CatRecorder(dims::Int, expected_responses::Int; requests...)
259263
out = []
260264
sizehint!(out, length(requests))
261265
for (name, request) in pairs(requests)
262266
extra = (;)
267+
if !haskey(request, :type)
268+
error("Must provide `type` for $name.")
269+
end
263270
if request.type in (:ability, :ability_stddev)
264271
data = empty_capacity(Float64, expected_responses)
265272
elseif request.type == :ability_distribution
273+
if !hasallkeys(request, :points, :estimator, :integrator)
274+
error("Must provide `points`, `estimator`, and `integrator` for $name.")
275+
end
266276
if dims == 0
267277
data = empty_capacity(Float64, length(request.points), expected_responses)
268278
else

0 commit comments

Comments
 (0)