|
| 1 | +mutable struct SwitchRegenerateState{T} |
| 2 | + weight::Float64 |
| 3 | + score::Float64 |
| 4 | + noise::Float64 |
| 5 | + prev_trace::Trace |
| 6 | + trace::Trace |
| 7 | + index::Int |
| 8 | + retdiff::Diff |
| 9 | + SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) |
| 10 | +end |
| 11 | + |
| 12 | +function process!(gen_fn::Switch{C, N, K, T}, |
| 13 | + index::Int, |
| 14 | + index_argdiff::Diff, |
| 15 | + args::Tuple, |
| 16 | + kernel_argdiffs::Tuple, |
| 17 | + selection::Selection, |
| 18 | + state::SwitchRegenerateState{T}) where {C, N, K, T} |
| 19 | + branch_fn = getfield(gen_fn.branches, index) |
| 20 | + merged = get_selected(get_choices(state.prev_trace), complement(selection)) |
| 21 | + new_trace, weight = generate(branch_fn, args, merged) |
| 22 | + retdiff = UnknownChange() |
| 23 | + weight -= project(state.prev_trace, complement(selection)) |
| 24 | + weight += (project(new_trace, selection) - project(state.prev_trace, selection)) |
| 25 | + state.index = index |
| 26 | + state.weight = weight |
| 27 | + state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) |
| 28 | + state.score = get_score(new_trace) |
| 29 | + state.trace = new_trace |
| 30 | + state.retdiff = retdiff |
| 31 | +end |
| 32 | + |
| 33 | +function process!(gen_fn::Switch{C, N, K, T}, |
| 34 | + index::Int, |
| 35 | + index_argdiff::NoChange, |
| 36 | + args::Tuple, |
| 37 | + kernel_argdiffs::Tuple, |
| 38 | + selection::Selection, |
| 39 | + state::SwitchRegenerateState{T}) where {C, N, K, T} |
| 40 | + new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) |
| 41 | + state.index = index |
| 42 | + state.weight = weight |
| 43 | + state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) |
| 44 | + state.score = get_score(new_trace) |
| 45 | + state.trace = new_trace |
| 46 | + state.retdiff = retdiff |
| 47 | +end |
| 48 | + |
| 49 | +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state) |
| 50 | + |
| 51 | +function regenerate(trace::SwitchTrace{T}, |
| 52 | + args::Tuple, |
| 53 | + argdiffs::Tuple, |
| 54 | + selection::Selection) where T |
| 55 | + gen_fn = trace.gen_fn |
| 56 | + index, index_argdiff = args[1], argdiffs[1] |
| 57 | + state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace) |
| 58 | + process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], selection, state) |
| 59 | + return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.retdiff |
| 60 | +end |
0 commit comments