Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Expand Down
11 changes: 11 additions & 0 deletions docs/src/ref/extending.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,17 @@ If your generative function has trainable parameters, then implement:

- [`accumulate_param_gradients!`](@ref)

#### Supporting trace serialization
To support trace serialization, a trace type of type `T` for a generative function of type `G` must convertable into a `SerializableTrace` object, and must be recoverable from a `SerializableTrace` object and the generative function.
```@docs
SerializableTrace
to_serializable_trace
from_serializable_trace
```
A user must implement `to_serializable_Trace(::T)`, and `from_serializable_Trace(::ST, ::G)` for some concrete type `ST <: SerializableTrace`. This may be a custom type, or the user may use the built-in type
```@docs
GenericSerializableTrace
```

## Custom modeling languages

Expand Down
11 changes: 11 additions & 0 deletions docs/src/ref/gfi.md
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,17 @@ The set of elements (either arguments, random choices, or trainable parameters)
If the return value of the function is conditionally dependent on any element in the gradient source set given the arguments and values of all other random choices, for all possible traces of the function, then the generative function requires a *return value gradient* to compute gradients with respect to elements of the gradient source set.
This static property of the generative function is reported by [`accepts_output_grad`](@ref).

## Serialization
To serialize a trace `tr` for a generative function `gf`
(stave the trace to disk), a user may call
```julia
serialize_trace(filename_or_io::Union{IO, AbstractString}, tr)
```
To recover the trace, a user may call
```julia
deserialized_tr = deserialize_trace(filename_or_io, gf)
```

## Generative function interface

The complete set of methods in the generative function interface (GFI) is:
Expand Down
3 changes: 3 additions & 0 deletions src/Gen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ include("trie.jl")
# generative function interface
include("gen_fn_interface.jl")

# serialization/deserialization for traces
include("serialization.jl")

# built-in data types for arg-diff and ret-diff values
include("diff.jl")

Expand Down
1 change: 1 addition & 0 deletions src/dynamic/dynamic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ function gen_fn_changed_error(addr)
error("Generative function changed at address: $addr")
end

include("serialization.jl")
include("simulate.jl")
include("generate.jl")
include("propose.jl")
Expand Down
48 changes: 48 additions & 0 deletions src/dynamic/serialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
function _record_to_serializable(r::ChoiceOrCallRecord{T}) where {T <: Trace}
@assert !r.is_choice
return ChoiceOrCallRecord(to_serializable_trace(r.subtrace_or_retval), r.score, r.noise, r.is_choice)
end
function _record_to_serializable(r::ChoiceOrCallRecord)
@assert r.is_choice
return r
end
function _record_from_serializable(r::ChoiceOrCallRecord{T}, gf::GenerativeFunction) where {T <: SerializableTrace}
@assert !r.is_choice
return ChoiceOrCallRecord(from_serializable_trace(r.subtrace_or_retval, gf), r.score, r.noise, r.is_choice)
end
function _record_from_serializable(r::ChoiceOrCallRecord, dist::Distribution)
@assert r.is_choice
return r
end
function _trie_to_serializable(trie::Trie)
triemap(trie, identity, _record_to_serializable)
end
function to_serializable_trace(tr::DynamicDSLTrace)
return GenericSerializableTrace(
_trie_to_serializable(tr.trie),
(tr.isempty, tr.score, tr.noise, tr.args, tr.retval)
)
end

# since a Dynamic Gen Function doesn't store
# what sub-generative-function is at which address,
# we have to run the generative function to get access to this!
mutable struct GFDeserializeState
trace::DynamicDSLTrace
serialized::GenericSerializableTrace
end
function from_serializable_trace(st::GenericSerializableTrace, gen_fn::DynamicDSLFunction{T}) where T
trace = DynamicDSLTrace{T}(gen_fn, Trie{Any, ChoiceOrCallRecord}(), st.properties...)
state = GFDeserializeState(trace, st)
exec(gen_fn, state, trace.args)
return trace
end
function traceat(state::GFDeserializeState, dist_or_gen_fn, args, key)
record = _record_from_serializable(state.serialized.subtraces[key], dist_or_gen_fn)
state.trace.trie[key] = record
return record.is_choice ? record.subtrace_or_retval : get_retval(record.subtrace_or_retval)
end
function splice(state::GFDeserializeState, gf::DynamicDSLFunction, args::Tuple)
return exec(gf, state, args)
end
read_param(::GFDeserializeState, ::Symbol) = nothing
3 changes: 3 additions & 0 deletions src/dynamic/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ mutable struct DynamicDSLTrace{T} <: Trace
# retval is not known yet
new(gen_fn, trie, true, 0, 0, args)
end
function DynamicDSLTrace{T}(gen_fn::T, trie, isempty, score, noise, args, retval) where {T}
new(gen_fn, trie, isempty, score, noise, args, retval)
end
end

set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval)
Expand Down
7 changes: 7 additions & 0 deletions src/modeling_library/call_at/call_at.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,11 @@ function accumulate_param_gradients!(trace::CallAtTrace, retval_grad)
(kernel_input_grads..., nothing)
end

function to_serializable_trace(tr::CallAtTrace)
return GenericSerializableTrace(to_serializable_trace(tr.subtrace), tr.key)
end
function from_serializable_trace(st::GenericSerializableTrace, gf::CallAtCombinator)
return get_trace_type(gf)(gf, from_serializable_trace(st.subtraces, gf.kernel), st.properties)
end

export call_at
8 changes: 8 additions & 0 deletions src/modeling_library/choice_at/choice_at.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap}
end
get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr)
has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr)
has_value(choices::ChoiceAtChoiceMap, addr) = addr == choices.key
function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K}
choices.key == addr ? choices.value : throw(KeyError(choices, addr))
end
Expand Down Expand Up @@ -172,4 +173,11 @@ function accumulate_param_gradients!(trace::ChoiceAtTrace, retval_grad)
(kernel_arg_grads[2:end]..., nothing)
end

function to_serializable_trace(tr::ChoiceAtTrace)
return GenericSerializableTrace(nothing, (tr.value, tr.key, tr.kernel_args, tr.score))
end
function from_serializable_trace(st::GenericSerializableTrace, gf::ChoiceAtCombinator)
return get_trace_type(gf)(gf, st.properties...)
end

export choice_at
7 changes: 7 additions & 0 deletions src/modeling_library/custom_determ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,11 @@ has_argument_grads(gen_fn::CustomUpdateGF) = tuple(fill(nothing, num_args(gen_fn

apply_with_state(gen_fn::CustomUpdateGF, args) = error("not implemented")

function to_serializable_trace(tr::CustomDetermGFTrace)
return GenericSerializableTrace(nothing, (tr.retval, tr.state, tr.args))
end
function from_serializable_trace(st::GenericSerializableTrace, gf::CustomDetermGF)
return get_trace_type(gf)(st.properties..., gf)
end

export CustomUpdateGF, num_args
2 changes: 2 additions & 0 deletions src/modeling_library/map/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ function get_prev_and_new_lengths(args::Tuple, prev_trace)
(new_length, prev_length)
end

_gen_fn_at_addr(gf::Map, _) = gf.kernel

include("assess.jl")
include("propose.jl")
include("simulate.jl")
Expand Down
21 changes: 21 additions & 0 deletions src/modeling_library/recurse/recurse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,27 @@ function get_aggregation_constraints(constraints::ChoiceMap, cur::Int)
get_submap(constraints, (cur, Val(:aggregation)))
end

function to_serializable_trace(tr::RecurseTrace)
return GenericSerializableTrace(
(
Dict(k => to_serializable_trace(subtr) for (k, subtr) in tr.production_traces),
Dict(k => to_serializable_trace(subtr) for (k, subtr) in tr.aggregation_traces)
),
(tr.max_branch, tr.score, tr.root_idx, tr.num_has_choices)
)
end
function from_serializable_trace(st::GenericSerializableTrace, gf::Recurse{S, T}) where {S, T}
production_traces = PersistentHashMap{Int, S}()
for (k, subst) in st.subtraces[1]
production_traces = assoc(production_traces, k, from_serializable_trace(subst, gf.production_kern))
end
aggregation_traces = PersistentHashMap{Int, T}()
for (k, subst) in st.subtraces[2]
aggregation_traces = assoc(aggregation_traces, k, from_serializable_trace(subst, gf.aggregation_kern))
end
return get_trace_type(gf)(gf, production_traces, aggregation_traces, st.properties...)
end

############
# simulate #
############
Expand Down
13 changes: 13 additions & 0 deletions src/modeling_library/switch/switch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ function (gen_fn::Switch{C})(index::C, args...) where C
end

include("trace.jl")

function to_serializable_trace(tr::SwitchTrace)
GenericSerializableTrace(to_serializable_trace(tr.branch), (tr.index, tr.retval, tr.args, tr.score, tr.noise))
end
function from_serializable_trace(c::GenericSerializableTrace, gf::Switch)
(index, retval, args, score, noise) = c.properties
GenericSerializableTrace(
gf, index,
from_serializable_trace(c.subtraces, gf.branches[index]),
retval, args, score, noise
)
end

include("assess.jl")
include("propose.jl")
include("simulate.jl")
Expand Down
2 changes: 2 additions & 0 deletions src/modeling_library/unfold/unfold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ function check_length(len::Int)
end
end

_gen_fn_at_addr(gf::Unfold, _) = gf.kernel

include("simulate.jl")
include("generate.jl")
include("propose.jl")
Expand Down
17 changes: 17 additions & 0 deletions src/modeling_library/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,20 @@ function vector_remove_deleted_applications(subtraces, retval, prev_length, new_
end
(subtraces, retval)
end

#################
# Serialization #
#################
function to_serializable_trace(trace::VectorTrace)
GenericSerializableTrace(
[to_serializable_trace(st) for st in trace.subtraces],
(trace.retval, trace.args, trace.len, trace.num_nonempty, trace.score, trace.noise)
)
end
function from_serializable_trace(st::GenericSerializableTrace, gf::GenerativeFunction{<:Any, VectorTrace{GenFnType, T, U}}) where {GenFnType, T, U}
subtraces = PersistentVector{U}(
[from_serializable_trace(serialized_subtrace, _gen_fn_at_addr(gf, i))
for (i, serialized_subtrace) in enumerate(st.subtraces)]
)
get_trace_type(gf)(gf, subtraces, st.properties...)
end
62 changes: 62 additions & 0 deletions src/serialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using Serialization: serialize, deserialize

"""
SerializableTrace

A representation of a `Trace` which can be serialized. Obtainable via `to_serializable_trace`.
This does not need to contain the `GenerativeFunction` which produced the trace;
to deserialize (using `from_serializable_trace`), the `GenerativeFunction` must be provided.
"""
abstract type SerializableTrace end

"""
to_serializable_trace(trace::Trace)

Get a SerializableTrace representing the `trace` in a serializable manner.
"""
function to_serializable_trace(trace::Trace)
error("Not implemented")
end

"""
from_serializable_trace(st::SerializableTrace, fn::GenerativeFunction)

Get the trace of the given generative function encoded by the serializable trace object.
"""
function from_serializable_trace(::SerializableTrace, ::GenerativeFunction)
error("Not implemented.")
end

"""
serialize_trace(stream::IO, trace::Trace)
serialize_trace(filename::AbstractString, trace::Trace)

Serialize the given trace to the given stream or file, by converting to a `SerializableTrace`.
"""
function serialize_trace(filename_or_io::Union{IO, AbstractString}, trace::Trace)
return serialize(filename_or_io, to_serializable_trace(trace))
end

"""
deserialize_trace(stream::IO, gen_fn::GenerativeFunction)
deserialize_trace(filename::AbstractString, gen_fn::GenerativeFunction)

Deserialize the trace for the given generative function stored in the given stream or file
(as saved via `serialize_trace`).
"""
function deserialize_trace(filename_or_io::Union{IO, AbstractString}, gf::GenerativeFunction)
return from_serializable_trace(deserialize(filename_or_io), gf)
end

"""
GenericSerializableTrace <: SerializableTrace

A SerializableTrace which contains some subtraces which have been recursively converted
to `SerializableTrace`s, and some properties which are directly serializable.
"""
struct GenericSerializableTrace{S, P} <: SerializableTrace
subtraces::S
properties::P
end

export to_serializable_trace, from_serializable_trace, serialize_trace, deserialize_trace
9 changes: 6 additions & 3 deletions src/static_ir/static_ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ function generate_generative_function(ir::StaticIR, name::Symbol; track_diffs=fa
end

function generate_generative_function(ir::StaticIR, name::Symbol, options::StaticIRGenerativeFunctionOptions)
gen_fn_type_name = gensym("StaticGenFunction_$name")

(trace_defns, trace_struct_name) = generate_trace_type_and_methods(ir, name, options)
(trace_defns, trace_struct_name, tracefields) = generate_trace_type_and_methods(ir, name, options)

gen_fn_type_name = gensym("StaticGenFunction_$name")
return_type = ir.return_node.typ
trace_type = trace_struct_name
has_argument_grads = tuple(map((node) -> node.compute_grad, ir.arg_nodes)...)
Expand All @@ -61,7 +61,10 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati
$(GlobalRef(Gen, :get_gen_fn_type))(::Type{$trace_struct_name}) = $gen_fn_type_name
$(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options))
end
Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}())))

serialization_code = generate_serialization_methods(ir, trace_struct_name, gen_fn_type_name, tracefields)

Expr(:block, trace_defns, gen_fn_defn, serialization_code, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}())))
end

include("print_ir.jl")
Expand Down
Loading