|
| 1 | +struct RecordedCatLoop |
| 2 | + cat_loop::CatLoop{<: CatRules} |
| 3 | + recorder::CatRecorder |
| 4 | + item_bank::Union{AbstractItemBank, Nothing} |
| 5 | +end |
| 6 | + |
| 7 | +function _prepare_get_response!(kwargs) |
| 8 | + has_responses = haskey(kwargs, :responses) |
| 9 | + has_get_response = haskey(kwargs, :get_response) |
| 10 | + if has_responses && has_get_response |
| 11 | + error("Cannot provide both `responses` and `get_response`.") |
| 12 | + elseif !has_responses && !has_get_response |
| 13 | + error("Must provide either `responses` or `get_response`.") |
| 14 | + elseif has_get_response |
| 15 | + return nothing, pop!(kwargs, :get_response) |
| 16 | + else |
| 17 | + responses = pop!(kwargs, :responses) |
| 18 | + return responses, Sim.auto_responder(responses) |
| 19 | + end |
| 20 | +end |
| 21 | + |
| 22 | +function _walk_find_type(obj, typ, out=[]) |
| 23 | + if obj isa typ |
| 24 | + push!(out, obj) |
| 25 | + end |
| 26 | + for fieldname in propertynames(obj) |
| 27 | + _walk_find_type(getfield(obj, fieldname), typ, out) |
| 28 | + end |
| 29 | + return out |
| 30 | +end |
| 31 | + |
| 32 | +function _find_mean_ability(rules) |
| 33 | + if rules.ability_estimator isa MeanAbilityEstimator |
| 34 | + return rules.ability_estimator |
| 35 | + end |
| 36 | + result = _walk_find_type(rules.next_item, MeanAbilityEstimator) |
| 37 | + if !isempty(result) |
| 38 | + return result[1] |
| 39 | + end |
| 40 | + result = _walk_find_type(rules.termination_condition, MeanAbilityEstimator) |
| 41 | + if !isempty(result) |
| 42 | + return result[1] |
| 43 | + end |
| 44 | + return nothing |
| 45 | +end |
| 46 | + |
| 47 | +function _find_ability_variance(rules) |
| 48 | + result = _walk_find_type(rules.next_item, AbilityVariance) |
| 49 | + if !isempty(result) |
| 50 | + return result[1] |
| 51 | + end |
| 52 | + return nothing |
| 53 | +end |
| 54 | + |
| 55 | +function enrich_recorder_requests(old_requests, rules) |
| 56 | + requests = Dict() |
| 57 | + for (k, v) in pairs(old_requests) |
| 58 | + new_v = Dict{Symbol, Any}(pairs(v)) |
| 59 | + type = get(new_v, :type, nothing) |
| 60 | + if type in (:ability, :ability_distribution, :ability_stddev) |
| 61 | + if haskey(new_v, :estimator) && haskey(new_v, :source) |
| 62 | + error("Cannot provide both `estimator` and `source` for request `$k`.") |
| 63 | + elseif !haskey(new_v, :estimator) |
| 64 | + if !haskey(new_v, :source) |
| 65 | + error("Must provide either `estimator` or `source` for request `$k`.") |
| 66 | + end |
| 67 | + source = new_v[:source] |
| 68 | + if source != :any |
| 69 | + error("Not implemented yet: `source = $source` for request `$k`.") |
| 70 | + end |
| 71 | + if type == :ability |
| 72 | + new_v[:estimator] = rules.ability_estimator |
| 73 | + elseif type == :ability_stddev |
| 74 | + error("Not implemented yet: `type = :ability_stddev` for request `$k`.") |
| 75 | + elseif type == :ability_distribution |
| 76 | + estimator = nothing |
| 77 | + integrator = nothing |
| 78 | + mean_ability = _find_mean_ability(rules) |
| 79 | + if mean_ability === nothing |
| 80 | + ability_variance = _find_ability_variance(rules) |
| 81 | + if ability_variance === nothing |
| 82 | + error("Cannot find a `MeanAbilityEstimator` or `AbilityVariance` in the rules for request `$k`.") |
| 83 | + end |
| 84 | + estimator = ability_variance.dist_est |
| 85 | + integrator = ability_variance.integrator |
| 86 | + else |
| 87 | + estimator = distribution_estimator(mean_ability) |
| 88 | + integrator = mean_ability.integrator |
| 89 | + end |
| 90 | + new_v[:estimator] = estimator |
| 91 | + if !haskey(new_v, :integrator) |
| 92 | + new_v[:integrator] = integrator |
| 93 | + end |
| 94 | + if !haskey(new_v, :points) |
| 95 | + integrator = get_integrator(new_v[:integrator]) |
| 96 | + if !(integrator isa AnyGridIntegrator) |
| 97 | + error("Must provide `points` for request `$k` when `integrator` is not an `AnyGridIntegrator`.") |
| 98 | + end |
| 99 | + new_v[:points] = get_grid(integrator) |
| 100 | + end |
| 101 | + end |
| 102 | + end |
| 103 | + end |
| 104 | + requests[k] = NamedTuple(new_v) |
| 105 | + end |
| 106 | + return requests |
| 107 | +end |
| 108 | + |
| 109 | +""" |
| 110 | +```julia |
| 111 | +RecordedCatLoop(; |
| 112 | + rules::CatRules, |
| 113 | + item_bank::AbstractItemBank = nothing, |
| 114 | + responses::Union{Nothing, Vector{ResponseType}} = nothing, |
| 115 | + dims::Union{Nothing, Tuple{Int, Int}} = nothing, |
| 116 | + expected_responses::Int = 0, |
| 117 | + get_response::Function = nothing, |
| 118 | + new_response_callback::Function = nothing, |
| 119 | + new_response_callbacks::Vector{Function} = Any[] |
| 120 | + requests... |
| 121 | +) |
| 122 | +``` |
| 123 | +
|
| 124 | +This `RecordedCatLoop` is a simplified construction of a `[CatRules](@ref)`-based `[CatLoop](@ref)` and `[CatRecorder](@ref)`. |
| 125 | +
|
| 126 | +It can be constructed with just some cat `rules`, an `item_bank`, and a response memory `responses`, as well as usually one or more `requests` for the `[CatRecorder](@ref). |
| 127 | +In this case `dims` are provided by the `item_bank`, and `expected_responses` is set to the length of `responses` as well as used to provide responses using `get_responses`, otherwise the respective arguments must be provided. |
| 128 | +The arguments `get_response`, `new_response_callback`, and `new_response_callbacks` are passed to the underlying `CatLoop`. |
| 129 | +
|
| 130 | +The resulting `RecordedCatLoop` can be run directly with run_cat. |
| 131 | +""" |
| 132 | +function RecordedCatLoop(; kwargs...) |
| 133 | + kwargs = Dict(kwargs) |
| 134 | + responses, get_response = _prepare_get_response!(kwargs) |
| 135 | + local expected_responses, rules |
| 136 | + if responses !== nothing |
| 137 | + expected_responses = length(responses) |
| 138 | + else |
| 139 | + expected_responses = pop!(kwargs, :expected_responses, 0) |
| 140 | + end |
| 141 | + if haskey(kwargs, :rules) |
| 142 | + rules = pop!(kwargs, :rules) |
| 143 | + else |
| 144 | + error("Must provide `rules`.") |
| 145 | + end |
| 146 | + new_response_callback = pop!(kwargs, :new_response_callback, nothing) |
| 147 | + new_response_callbacks = pop!(kwargs, :new_response_callbacks, Any[]) |
| 148 | + local dims |
| 149 | + item_bank = nothing |
| 150 | + if !haskey(kwargs, :item_bank) && !haskey(kwargs, :dims) |
| 151 | + error("Must provide either `item_bank` or `dims`.") |
| 152 | + end |
| 153 | + if haskey(kwargs, :item_bank) |
| 154 | + item_bank = pop!(kwargs, :item_bank) |
| 155 | + dims = domdims(item_bank) |
| 156 | + end |
| 157 | + if haskey(kwargs, :dims) |
| 158 | + dims = pop!(kwargs, :dims) |
| 159 | + end |
| 160 | + requests = enrich_recorder_requests(kwargs, rules) |
| 161 | + cat_recorder = CatRecorder(dims, expected_responses; requests...) |
| 162 | + RecordedCatLoop( |
| 163 | + CatLoop(; |
| 164 | + rules, |
| 165 | + get_response, |
| 166 | + new_response_callback, |
| 167 | + new_response_callbacks, |
| 168 | + recorder=cat_recorder |
| 169 | + ), |
| 170 | + cat_recorder, |
| 171 | + item_bank |
| 172 | + ) |
| 173 | +end |
| 174 | + |
| 175 | +""" |
| 176 | +$TYPEDSIGNATURES |
| 177 | +
|
| 178 | +Run a given [RecordedCatLoop](@ref) by delegating the call to the wrapped [CatLoop](@ref). |
| 179 | +
|
| 180 | +In case `item_bank` is not provided, the item bank provided during the construction of `RecordedCatLoop` is used. |
| 181 | +""" |
| 182 | +function run_cat(loop::RecordedCatLoop, |
| 183 | + item_bank::AbstractItemBank; |
| 184 | + ib_labels = nothing) |
| 185 | + run_cat(loop.cat_loop, item_bank; ib_labels=ib_labels) |
| 186 | +end |
| 187 | + |
| 188 | +function run_cat(loop::RecordedCatLoop; ib_labels = nothing) |
| 189 | + if loop.item_bank === nothing |
| 190 | + error("Trying to run a RecordedCatLoop without an item bank when no item bank was provided at construction time.") |
| 191 | + end |
| 192 | + run_cat(loop, loop.item_bank; ib_labels=ib_labels) |
| 193 | +end |
0 commit comments