diff --git a/README.md b/README.md index 8ab89b0e1..eb501f96d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Gen.jl -[![Build Status](https://travis-ci.com/probcomp/Gen.jl.svg?branch=master)](https://travis-ci.com/probcomp/Gen.jl) + +[![Build Status](https://travis-ci.com/probcomp/Gen.jl.svg?branch=master)](https://app.travis-ci.com/github/probcomp/Gen.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://probcomp.github.io/Gen.jl/stable) [![](https://img.shields.io/badge/docs-dev-blue.svg)](https://probcomp.github.io/Gen.jl/dev) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index c065b1b32..4a23b7cfa 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -8,13 +8,58 @@ ChoiceMap Choice maps are constructed by users to express observations and/or constraints on the traces of generative functions. Choice maps are also returned by certain Gen inference methods, and are used internally by various Gen inference methods. +A choicemap a tree, whose leaf nodes store a single value, and whose internal nodes provide addresses +for sub-choicemaps. Leaf nodes have type: +```@docs +ValueChoiceMap +``` + +### Example Usage Overview + +Choicemaps store values nested in a tree where each node posesses an address for each subtree. +A leaf-node choicemap simply contains a value, and has its value looked up via: +```julia +value = choicemap[] +``` +If a choicemap has a value choicemap at address `:a`, the value it stores is looked up via: +```julia +value = choicemap[:a] +``` +A choicemap may also have a non-value choicemap stored at an address. For instance, +if a choicemap has another choicemap stored at address `:a`, and this internal choicemap +has a valuechoicemap stored at address `:b` and another at `:c`, we could perform the following lookups: +```julia +value1 = choicemap[:a => :b] +value2 = choicemap[:a => :c] +``` +Nesting can be arbitrarily deep, and the keys can be arbitrary values; for instance +choicemaps can be constructed with values at the following nested addresses: +```julia +value = choicemap[:a => :b => :c => 4 => 1.63 => :e] +value = choicemap[:a => :b => :a => 2 => "alphabet" => :e] +``` +To get a sub-choicemap, use `get_submap`: +```julia +value1 = choicemap[:a => :b] +submap = get_submap(choicemap, :a) +value1 == submap[:b] # is true + +value_submap = get_submap(choicemap, :a => :b) +value_submap[] == value1 # is true +``` +One can think of `ValueChoiceMap`s at storing being a choicemap which has a value at "nesting level zero", +while other choicemaps have values at "nesting level" one or higher. + +### Interface + Choice maps provide the following methods: ```@docs +get_submap +get_submaps_shallow has_value get_value -get_submap get_values_shallow -get_submaps_shallow +get_nonvalue_submaps_shallow to_array from_array get_selected @@ -23,7 +68,7 @@ Note that none of these methods mutate the choice map. Choice maps also implement: -- `Base.isempty`, which tests of there are no random choices in the choice map +- `Base.isempty`, which returns `false` if the choicemap contains no value or submaps, and `true` otherwise. - `Base.merge`, which takes two choice maps, and returns a new choice map containing all random choices in either choice map. It is an error if the choice maps both have values at the same address, or if one choice map has a value at an address that is the prefix of the address of a value in the other choice map. @@ -50,3 +95,35 @@ choicemap set_value! set_submap! ``` + +### Implementing custom choicemap types + +To implement a custom choicemap, one must implement +`get_submap` and `get_submaps_shallow`. +To avoid method ambiguity with the default +`get_submap(::ChoiceMap, ::Pair)`, one must implement both +```julia +get_submap(::CustomChoiceMap, addr) +``` +and +```julia +get_submap(::CustomChoiceMap, addr::Pair) +``` +To use the default implementation of `get_submap(_, ::Pair)`, +one may define +```julia +get_submap(c::CustomChoiceMap, addr::Pair) = _get_choicemap(c, addr) +``` + +Once `get_submap` and `get_submaps_shallow` are defined, default +implementations are provided for: +- `has_value` +- `get_value` +- `get_values_shallow` +- `get_nonvalue_submaps_shallow` +- `to_array` +- `get_selected` + +If one wishes to support `from_array`, they must implement +`_from_array`, as described in the documentation for +[`from_array`](@ref). \ No newline at end of file diff --git a/docs/src/ref/distributions.md b/docs/src/ref/distributions.md index 134c84e8c..b6ad1a81f 100644 --- a/docs/src/ref/distributions.md +++ b/docs/src/ref/distributions.md @@ -1,7 +1,13 @@ # Probability Distributions -Gen provides a library of built-in probability distributions, and three ways of -defining custom distributions, each of which are explained below: +In Gen, a probability distribution is a generative function which makes a single random choice +and returns the value of this choice. The choicemap for a probability distribution +is always a [`ValueChoiceMap`](@ref). In addition to supporting the regular `GFI` methods, +every distribution supports the methods [`random`](@ref) and [`logpdf`](@ref), described +in the [Distribution API](@ref custom_distributions). + +Gen provides a library of built-in probability distributions, and two ways of +writing custom distributions, both of which are explained below: 1. The [`@dist` constructor](@ref dist_dsl), for a distribution that can be expressed as a simple deterministic transformation (technically, a diff --git a/docs/src/ref/extending.md b/docs/src/ref/extending.md index 7f9dfd480..b1d759b3a 100644 --- a/docs/src/ref/extending.md +++ b/docs/src/ref/extending.md @@ -110,7 +110,7 @@ Gen's Distribution interface directly, as defined below. Probability distributions are singleton types whose supertype is `Distribution{T}`, where `T` indicates the data type of the random sample. ```julia -abstract type Distribution{T} end +abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace} end ``` A new Distribution type must implement the following methods: @@ -146,6 +146,9 @@ has_output_grad logpdf_grad ``` +Any custom distribution will automatically be a `GenerativeFunction` since `Distribution <: GenerativeFunction`; +implementations of all GFI methods are automatically provided in terms of `random` and `logpdf`. + ## Custom generative functions We recommend the following steps for implementing a new type of generative function, and also looking at the implementation for the [`DynamicDSLFunction`](@ref) type as an example. diff --git a/src/Gen.jl b/src/Gen.jl index 29113ff11..64b3c367d 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -43,18 +43,21 @@ include("backprop.jl") include("address.jl") # abstract and built-in concrete choice map data types -include("choice_map.jl") +include("choice_map/choice_map.jl") # a homogeneous trie data type (not for use as choice map) include("trie.jl") +# built-in data types for arg-diff and ret-diff values +include("diff.jl") + # generative function interface include("gen_fn_interface.jl") -# built-in data types for arg-diff and ret-diff values -include("diff.jl") +# distribution abstract type +include("distribution.jl") -# built-in probability disributions +# built-in probability disributions; distribution dsl; combinators include("modeling_library/modeling_library.jl") # optimization of trainable parameters diff --git a/src/address.jl b/src/address.jl index fe0fb30e9..9d751fb46 100644 --- a/src/address.jl +++ b/src/address.jl @@ -151,6 +151,7 @@ A hierarchical selection whose keys are among its type parameters. struct StaticSelection{T,U} <: HierarchicalSelection subselections::NamedTuple{T,U} end +StaticSelection(::NamedTuple{(), Tuple{}}) = EmptySelection() function Base.isempty(selection::StaticSelection{T,U}) where {T,U} length(R) == 0 && all(isempty(node) for node in selection.subselections) @@ -208,7 +209,7 @@ function StaticSelection(other::HierarchicalSelection) (keys, subselections) = ((), ()) end types = map(typeof, subselections) - StaticSelection{keys,Tuple{types...}}(NamedTuple{keys}(subselections)) + StaticSelection(NamedTuple{keys}(subselections)) end export StaticSelection diff --git a/src/choice_map.jl b/src/choice_map.jl deleted file mode 100644 index fc72524cd..000000000 --- a/src/choice_map.jl +++ /dev/null @@ -1,1038 +0,0 @@ -######################### -# choice map interface # -######################### - -""" - schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} - -Return the (top-level) address schema for the given choice map. -""" -function get_address_schema end - -""" - submap = get_submap(choices::ChoiceMap, addr) - -Return the sub-assignment containing all choices whose address is prefixed by addr. - -It is an error if the assignment contains a value at the given address. If -there are no choices whose address is prefixed by addr then return an -`EmptyChoiceMap`. -""" -function get_submap end - -""" - value = get_value(choices::ChoiceMap, addr) - -Return the value at the given address in the assignment, or throw a KeyError if -no value exists. A syntactic sugar is `Base.getindex`: - - value = choices[addr] -""" -function get_value end - -""" - has_submap(choices::ChoiceMap, addr) - -Return true if there is a non-empty sub-assignment at the given address. -""" -function has_submap end - -""" - key_submap_iterable = get_submaps_shallow(choices::ChoiceMap) - -Return an iterator over tuples of the form `(key, submap::ChoiceMap)` for each top-level key -that has a non-empty sub-assignment. -""" -function get_submaps_shallow end - -""" - has_value(choices::ChoiceMap, addr) - -Return true if there is a value at the given address. -""" -function has_value end - -""" - key_submap_iterable = get_values_shallow(choices::ChoiceMap) - -Return an iterator over tuples of the form `(key, value)` for each -top-level key associated with a value. -""" -function get_values_shallow end - -""" - abstract type ChoiceMap end - -Abstract type for maps from hierarchical addresses to values. -""" -abstract type ChoiceMap end - -""" - Base.isempty(choices::ChoiceMap) - -Return true if there are no values in the assignment. -""" -function Base.isempty(::ChoiceMap) - true -end - - -@inline has_submap(choices::ChoiceMap, addr) = !has_value(choices, addr) && !isempty(get_submap(choices, addr)) -@inline get_submap(choices::ChoiceMap, addr) = EmptyChoiceMap() -@inline has_value(choices::ChoiceMap, addr) = false -@inline get_value(choices::ChoiceMap, addr) = throw(KeyError(addr)) -@inline Base.getindex(choices::ChoiceMap, addr) = get_value(choices, addr) - -@inline function _has_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - has_value(submap, rest) -end - -@inline function _get_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_value(submap, rest) -end - -@inline function _get_submap(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_submap(submap, rest) -end - -function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) - VERT = '\u2502' - PLUS = '\u251C' - HORZ = '\u2500' - LAST = '\u2514' - indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) - indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) - for i in vert_bars - indent_vert[i] = VERT - indent[i] = VERT - indent_last[i] = VERT - end - indent_vert_str = join(indent_vert) - indent_vert_last_str = join(indent_vert_last) - indent_str = join(indent) - indent_last_str = join(indent_last) - key_and_values = collect(get_values_shallow(choices)) - key_and_submaps = collect(get_submaps_shallow(choices)) - n = length(key_and_values) + length(key_and_submaps) - cur = 1 - for (key, value) in key_and_values - # For strings, `print` is what we want; `Base.show` includes quote marks. - # https://docs.julialang.org/en/v1/base/io-network/#Base.print - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") - cur += 1 - end - for (key, submap) in key_and_submaps - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") - _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) - cur += 1 - end -end - -function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) - _show_pretty(io, choices, 0, ()) -end - -# assignments that have static address schemas should also support faster -# accessors, which make the address explicit in the type (Val(:foo) instaed of -# :foo) -function static_get_value end -function static_get_submap end - -function _fill_array! end -function _from_array end - -""" - arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} - -Populate an array with values of choices in the given assignment. - -It is an error if each of the values cannot be coerced into a value of the -given type. - -# Implementation - -To support `to_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Populate `arr` with values from the given assignment, starting at `start_idx`, -and return the number of elements in `arr` that were populated. -""" -function to_array(choices::ChoiceMap, ::Type{T}) where {T} - arr = Vector{T}(undef, 32) - n = _fill_array!(choices, arr, 1) - @assert n <= length(arr) - resize!(arr, n) - arr -end - -function _fill_array!(value::T, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx - resize!(arr, 2 * start_idx) - end - arr[start_idx] = value - 1 -end - -function _fill_array!(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx + length(value) - resize!(arr, 2 * (start_idx + length(value))) - end - arr[start_idx:start_idx+length(value)-1] = value - length(value) -end - - -""" - choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) - -Return an assignment with the same address structure as a prototype -assignment, but with values read off from the given array. - -The order in which addresses are populated is determined by the prototype -assignment. It is an error if the number of choices in the prototype assignment -is not equal to the length the array. - -# Implementation - -To support `from_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - - (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Return an assignment with the same address structure as a prototype assignment, -but with values read off from `arr`, starting at position `start_idx`, and the -number of elements read from `arr`. -""" -function from_array(proto_choices::ChoiceMap, arr::Vector) - (n, choices) = _from_array(proto_choices, arr, 1) - if n != length(arr) - error("Dimension mismatch: $n, $(length(arr))") - end - choices -end - -function _from_array(::T, arr::Vector{T}, start_idx::Int) where {T} - (1, arr[start_idx]) -end - -function _from_array(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - n_read = length(value) - (n_read, arr[start_idx:start_idx+n_read-1]) -end - - -""" - choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - -Merge two choice maps. - -It is an error if the choice maps both have values at the same address, or if -one choice map has a value at an address that is the prefix of the address of a -value in the other choice map. -""" -function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - choices = DynamicChoiceMap() - for (key, value) in get_values_shallow(choices1) - choices.leaf_nodes[key] = value - end - for (key, node1) in get_submaps_shallow(choices1) - node2 = get_submap(choices2, key) - node = merge(node1, node2) - choices.internal_nodes[key] = node - end - for (key, value) in get_values_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has leaf node at $key") - end - if haskey(choices.internal_nodes, key) - error("choices1 has internal node at $key and choices2 has leaf node at $key") - end - choices.leaf_nodes[key] = value - end - for (key, node) in get_submaps_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has internal node at $key") - end - if !haskey(choices.internal_nodes, key) - # otherwise it should already be included - choices.internal_nodes[key] = node - end - end - return choices -end - -""" -Variadic merge of choice maps. -""" -function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) - reduce(Base.merge, choices_rest; init=choices1) -end - -function Base.:(==)(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || (get_value(b, addr) != value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || (get_value(a, addr) != value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if submap != get_submap(b, addr) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if submap != get_submap(a, addr) - return false - end - end - return true -end - -# This is modeled after -# https://github.com/JuliaLang/julia/blob/7bff5cdd0fab8d625e48b3a9bb4e94286f2ba18c/base/abstractdict.jl#L530-L537 -const hasha_seed = UInt === UInt64 ? 0x6d35bb51952d5539 : 0x952d5539 -function Base.hash(a::ChoiceMap, h::UInt) - hv = hasha_seed - for (addr, value) in get_values_shallow(a) - hv = xor(hv, hash(addr, hash(value))) - end - for (addr, submap) in get_submaps_shallow(a) - hv = xor(hv, hash(addr, hash(submap))) - end - - return hash(hv, h) -end - -function Base.isapprox(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || !isapprox(get_value(b, addr), value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || !isapprox(get_value(a, addr), value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if !isapprox(submap, get_submap(b, addr)) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if !isapprox(submap, get_submap(a, addr)) - return false - end - end - return true -end - - -export ChoiceMap -export get_address_schema -export has_submap -export get_submap -export get_value -export has_value -export get_submaps_shallow -export get_values_shallow -export static_get_value -export static_get_submap -export to_array, from_array - - -###################### -# static assignment # -###################### - -struct StaticChoiceMap{R,S,T,U} <: ChoiceMap - leaf_nodes::NamedTuple{R,S} - internal_nodes::NamedTuple{T,U} - isempty::Bool -end - -function StaticChoiceMap{R,S,T,U}(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - -function StaticChoiceMap(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - - -# invariant: all internal_nodes are nonempty - -function get_address_schema(::Type{StaticChoiceMap{R,S,T,U}}) where {R,S,T,U} - keys = Set{Symbol}() - for (key, _) in zip(R, S.parameters) - push!(keys, key) - end - for (key, _) in zip(T, U.parameters) - push!(keys, key) - end - StaticAddressSchema(keys) -end - -function Base.isempty(choices::StaticChoiceMap) - choices.isempty -end - -get_values_shallow(choices::StaticChoiceMap) = pairs(choices.leaf_nodes) -get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.internal_nodes) -has_value(choices::StaticChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::StaticChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) - -# NOTE: there is no static_has_value because this is known from the static -# address schema - -## has_value ## - -function has_value(choices::StaticChoiceMap, key::Symbol) - haskey(choices.leaf_nodes, key) -end - -## get_submap ## - -function get_submap(choices::StaticChoiceMap, key::Symbol) - if haskey(choices.internal_nodes, key) - choices.internal_nodes[key] - elseif haskey(choices.leaf_nodes, key) - throw(KeyError(key)) - else - EmptyChoiceMap() - end -end - -function static_get_submap(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.internal_nodes[A] -end - -## get_value ## - -function get_value(choices::StaticChoiceMap, key::Symbol) - choices.leaf_nodes[key] -end - -function static_get_value(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.leaf_nodes[A] -end - -# convert from any other schema that has only Val{:foo} addresses -function StaticChoiceMap(other::ChoiceMap) - leaf_keys_and_nodes = collect(get_values_shallow(other)) - internal_keys_and_nodes = collect(get_submaps_shallow(other)) - if length(leaf_keys_and_nodes) > 0 - (leaf_keys, leaf_nodes) = collect(zip(leaf_keys_and_nodes...)) - else - (leaf_keys, leaf_nodes) = ((), ()) - end - if length(internal_keys_and_nodes) > 0 - (internal_keys, internal_nodes) = collect(zip(internal_keys_and_nodes...)) - else - (internal_keys, internal_nodes) = ((), ()) - end - StaticChoiceMap( - NamedTuple{leaf_keys}(leaf_nodes), - NamedTuple{internal_keys}(internal_nodes), - isempty(other)) -end - -""" - choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - -Return an assignment that contains `choices1` as a sub-assignment under `key1` -and `choices2` as a sub-assignment under `key2`. -""" -function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - StaticChoiceMap(NamedTuple(), NamedTuple{(key1,key2)}((choices1, choices2)), - isempty(choices1) && isempty(choices2)) -end - -""" - (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - -Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. - -It is an error if there are any top-level values, or any non-empty top-level -sub-assignments at keys other than `key1` and `key2`. -""" -function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - if !isempty(get_values_shallow(choices)) || length(collect(get_submaps_shallow(choices))) > 2 - error("Not a pair") - end - a = get_submap(choices, key1) - b = get_submap(choices, key2) - (a, b) -end - -# TODO make a generated function? -function _fill_array!(choices::StaticChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for value in choices.leaf_nodes - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for node in choices.internal_nodes - n_written = _fill_array!(node, arr, idx) - idx += n_written - end - idx - start_idx -end - -@generated function _from_array( - proto_choices::StaticChoiceMap{R,S,T,U}, arr::Vector{V}, start_idx::Int) where {R,S,T,U,V} - leaf_node_keys = proto_choices.parameters[1] - leaf_node_types = proto_choices.parameters[2].parameters - internal_node_keys = proto_choices.parameters[3] - internal_node_types = proto_choices.parameters[4].parameters - - exprs = [quote idx = start_idx end] - leaf_node_names = [] - internal_node_names = [] - - # leaf nodes - for key in leaf_node_keys - value = gensym() - push!(leaf_node_names, value) - push!(exprs, quote - (n_read, $value) = _from_array(proto_choices.leaf_nodes.$key, arr, idx) - idx += n_read - end) - end - - # internal nodes - for key in internal_node_keys - node = gensym() - push!(internal_node_names, node) - push!(exprs, quote - (n_read, $node) = _from_array(proto_choices.internal_nodes.$key, arr, idx) - idx += n_read - end) - end - - quote - $(exprs...) - leaf_nodes_field = NamedTuple{R,S}(($(leaf_node_names...),)) - internal_nodes_field = NamedTuple{T,U}(($(internal_node_names...),)) - choices = StaticChoiceMap{R,S,T,U}(leaf_nodes_field, internal_nodes_field) - (idx - start_idx, choices) - end -end - -@generated function Base.merge(choices1::StaticChoiceMap{R,S,T,U}, - choices2::StaticChoiceMap{W,X,Y,Z}) where {R,S,T,U,W,X,Y,Z} - - # unpack first assignment type parameters - leaf_node_keys1 = choices1.parameters[1] - leaf_node_types1 = choices1.parameters[2].parameters - internal_node_keys1 = choices1.parameters[3] - internal_node_types1 = choices1.parameters[4].parameters - keys1 = (leaf_node_keys1..., internal_node_keys1...,) - - # unpack second assignment type parameters - leaf_node_keys2 = choices2.parameters[1] - leaf_node_types2 = choices2.parameters[2].parameters - internal_node_keys2 = choices2.parameters[3] - internal_node_types2 = choices2.parameters[4].parameters - keys2 = (leaf_node_keys2..., internal_node_keys2...,) - - # leaf vs leaf collision is an error - colliding_leaf_leaf_keys = intersect(leaf_node_keys1, leaf_node_keys2) - if !isempty(colliding_leaf_leaf_keys) - error("choices1 and choices2 both have leaf nodes at key(s): $colliding_leaf_leaf_keys") - end - - # leaf vs internal collision is an error - colliding_leaf_internal_keys = intersect(leaf_node_keys1, internal_node_keys2) - if !isempty(colliding_leaf_internal_keys) - error("choices1 has leaf node and choices2 has internal node at key(s): $colliding_leaf_internal_keys") - end - - # internal vs leaf collision is an error - colliding_internal_leaf_keys = intersect(internal_node_keys1, leaf_node_keys2) - if !isempty(colliding_internal_leaf_keys) - error("choices1 has internal node and choices2 has leaf node at key(s): $colliding_internal_leaf_keys") - end - - # internal vs internal collision is not an error, recursively call merge - colliding_internal_internal_keys = (intersect(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys1_exclusive = (setdiff(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys2_exclusive = (setdiff(internal_node_keys2, internal_node_keys1)...,) - - # leaf nodes named tuple - leaf_node_keys = (leaf_node_keys1..., leaf_node_keys2...,) - leaf_node_types = map(QuoteNode, (leaf_node_types1..., leaf_node_types2...,)) - leaf_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys1]..., - [Expr(:(.), :(choices2.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys2]...) - leaf_nodes = Expr(:call, - Expr(:curly, :NamedTuple, - QuoteNode(leaf_node_keys), - Expr(:curly, :Tuple, leaf_node_types...)), - leaf_node_values) - - # internal nodes named tuple - internal_node_keys = (internal_node_keys1_exclusive..., - internal_node_keys2_exclusive..., - colliding_internal_internal_keys...) - internal_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)) - for key in internal_node_keys1_exclusive]..., - [Expr(:(.), :(choices2.internal_nodes), QuoteNode(key)) - for key in internal_node_keys2_exclusive]..., - [Expr(:call, :merge, - Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)), - Expr(:(.), :(choices2.internal_nodes), QuoteNode(key))) - for key in colliding_internal_internal_keys]...) - internal_nodes = Expr(:call, - Expr(:curly, :NamedTuple, QuoteNode(internal_node_keys)), - internal_node_values) - - # construct assignment from named tuples - Expr(:call, :StaticChoiceMap, leaf_nodes, internal_nodes) -end - -export StaticChoiceMap -export pair, unpair - -####################### -# dynamic assignment # -####################### - -struct DynamicChoiceMap <: ChoiceMap - leaf_nodes::Dict{Any,Any} - internal_nodes::Dict{Any,Any} - function DynamicChoiceMap(leaf_nodes::Dict{Any,Any}, internal_nodes::Dict{Any,Any}) - new(leaf_nodes, internal_nodes) - end -end - -# invariant: all internal nodes are nonempty - -""" - struct DynamicChoiceMap <: ChoiceMap .. end - -A mutable map from arbitrary hierarchical addresses to values. - - choices = DynamicChoiceMap() - -Construct an empty map. - - choices = DynamicChoiceMap(tuples...) - -Construct a map containing each of the given (addr, value) tuples. -""" -function DynamicChoiceMap() - DynamicChoiceMap(Dict(), Dict()) -end - -function DynamicChoiceMap(tuples...) - choices = DynamicChoiceMap() - for tuple in tuples - if length(tuple) != 2 - error("Constructor accepts tuples of the form (address, value) only") - end - (addr, value) = tuple - choices[addr] = value - end - choices -end - -""" - choices = DynamicChoiceMap(other::ChoiceMap) - -Copy a choice map, returning a mutable choice map. -""" -function DynamicChoiceMap(other::ChoiceMap) - choices = DynamicChoiceMap() - for (addr, val) in get_values_shallow(other) - choices[addr] = val - end - for (addr, submap) in get_submaps_shallow(other) - set_submap!(choices, addr, DynamicChoiceMap(submap)) - end - choices -end - -""" - choices = choicemap() - -Construct an empty mutable choice map. -""" -function choicemap() - DynamicChoiceMap() -end - -""" - choices = choicemap(tuples...) - -Construct a mutable choice map initialized with given address, value tuples. -""" -function choicemap(tuples...) - DynamicChoiceMap(tuples...) -end - -get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() - -get_values_shallow(choices::DynamicChoiceMap) = choices.leaf_nodes - -get_submaps_shallow(choices::DynamicChoiceMap) = choices.internal_nodes - -has_value(choices::DynamicChoiceMap, addr::Pair) = _has_value(choices, addr) - -get_value(choices::DynamicChoiceMap, addr::Pair) = _get_value(choices, addr) - -get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::DynamicChoiceMap, addr) - if haskey(choices.internal_nodes, addr) - choices.internal_nodes[addr] - elseif haskey(choices.leaf_nodes, addr) - throw(KeyError(addr)) - else - EmptyChoiceMap() - end -end - -has_value(choices::DynamicChoiceMap, addr) = haskey(choices.leaf_nodes, addr) - -get_value(choices::DynamicChoiceMap, addr) = choices.leaf_nodes[addr] - -function Base.isempty(choices::DynamicChoiceMap) - isempty(choices.leaf_nodes) && isempty(choices.internal_nodes) -end - -# mutation (not part of the assignment interface) - -""" - set_value!(choices::DynamicChoiceMap, addr, value) - -Set the given value for the given address. - -Will cause any previous value or sub-assignment at this address to be deleted. -It is an error if there is already a value present at some prefix of the given address. - -The following syntactic sugar is provided: - - choices[addr] = value -""" -function set_value!(choices::DynamicChoiceMap, addr, value) - delete!(choices.internal_nodes, addr) - choices.leaf_nodes[addr] = value -end - -function set_value!(choices::DynamicChoiceMap, addr::Pair, value) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - node = choices.internal_nodes[first] - set_value!(node, rest, value) -end - -""" - set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) - -Replace the sub-assignment rooted at the given address with the given sub-assignment. -Set the given value for the given address. - -Will cause any previous value or sub-assignment at the given address to be deleted. -It is an error if there is already a value present at some prefix of address. -""" -function set_submap!(choices::DynamicChoiceMap, addr, new_node) - delete!(choices.leaf_nodes, addr) - delete!(choices.internal_nodes, addr) - if !isempty(new_node) - choices.internal_nodes[addr] = new_node - end -end - -function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - set_submap!(node, rest, new_node) -end - -Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) - -function _fill_array!(choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - leaf_keys_sorted = sort(collect(keys(choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - value = choices.leaf_nodes[key] - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for key in internal_node_keys_sorted - n_written = _fill_array!(get_submap(choices, key), arr, idx) - idx += n_written - end - idx - start_idx -end - -function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - @assert length(arr) >= start_idx - choices = DynamicChoiceMap() - leaf_keys_sorted = sort(collect(keys(proto_choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(proto_choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - (n_read, value) = _from_array(proto_choices.leaf_nodes[key], arr, idx) - idx += n_read - choices.leaf_nodes[key] = value - end - for key in internal_node_keys_sorted - (n_read, node) = _from_array(get_submap(proto_choices, key), arr, idx) - idx += n_read - choices.internal_nodes[key] = node - end - (idx - start_idx, choices) -end - -export DynamicChoiceMap -export choicemap -export set_value! -export set_submap! - - -####################################### -## vector combinator for assignments # -####################################### - -# TODO implement LeafVectorChoiceMap, which stores a vector of leaf nodes - -struct InternalVectorChoiceMap{T} <: ChoiceMap - internal_nodes::Vector{T} - is_empty::Bool -end - -function vectorize_internal(nodes::Vector{T}) where {T} - is_empty = all(map(isempty, nodes)) - InternalVectorChoiceMap(nodes, is_empty) -end - -# note some internal nodes may be empty - -get_address_schema(::Type{InternalVectorChoiceMap}) = VectorAddressSchema() - -Base.isempty(choices::InternalVectorChoiceMap) = choices.is_empty -has_value(choices::InternalVectorChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::InternalVectorChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::InternalVectorChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::InternalVectorChoiceMap, addr::Int) - if addr > 0 && addr <= length(choices.internal_nodes) - choices.internal_nodes[addr] - else - EmptyChoiceMap() - end -end - -function get_submaps_shallow(choices::InternalVectorChoiceMap) - ((i, choices.internal_nodes[i]) - for i=1:length(choices.internal_nodes) - if !isempty(choices.internal_nodes[i])) -end - -get_values_shallow(::InternalVectorChoiceMap) = () - -function _fill_array!(choices::InternalVectorChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for key=1:length(choices.internal_nodes) - n = _fill_array!(choices.internal_nodes[key], arr, idx) - idx += n - end - idx - start_idx -end - -function _from_array(proto_choices::InternalVectorChoiceMap{U}, arr::Vector{T}, start_idx::Int) where {T,U} - @assert length(arr) >= start_idx - nodes = Vector{U}(undef, length(proto_choices.internal_nodes)) - idx = start_idx - for key=1:length(proto_choices.internal_nodes) - (n_read, nodes[key]) = _from_array(proto_choices.internal_nodes[key], arr, idx) - idx += n_read - end - choices = InternalVectorChoiceMap(nodes, proto_choices.is_empty) - (idx - start_idx, choices) -end - -export InternalVectorChoiceMap -export vectorize_internal - - -#################### -# empty assignment # -#################### - -struct EmptyChoiceMap <: ChoiceMap end - -Base.isempty(::EmptyChoiceMap) = true -get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() -get_submaps_shallow(::EmptyChoiceMap) = () -get_values_shallow(::EmptyChoiceMap) = () - -_fill_array!(::EmptyChoiceMap, arr::Vector, start_idx::Int) = 0 -_from_array(::EmptyChoiceMap, arr::Vector, start_idx::Int) = (0, EmptyChoiceMap()) - -export EmptyChoiceMap - -############################################ -# Nested-dict–like accessor for choicemaps # -############################################ - -""" -Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than -the default syntax which looks like a flat dict of full keypaths. - -```jldoctest -julia> using Gen -julia> c = choicemap((:a, 1), - (:b => :c, 2)); -julia> cv = nested_view(c); -julia> c[:a] == cv[:a] -true -julia> c[:b => :c] == cv[:b][:c] -true -julia> length(cv) -2 -julia> length(cv[:b]) -1 -julia> sort(collect(keys(cv))) -[:a, :b] -julia> sort(collect(keys(cv[:b]))) -[:c] -``` -""" -struct ChoiceMapNestedView - choice_map::ChoiceMap -end - -function Base.getindex(choices::ChoiceMapNestedView, addr) - if has_value(choices.choice_map, addr) - return get_value(choices.choice_map, addr) - end - submap = get_submap(choices.choice_map, addr) - if isempty(submap) - throw(KeyError(addr)) - end - ChoiceMapNestedView(submap) -end - -function Base.iterate(c::ChoiceMapNestedView) - inner_iterator = Base.Iterators.flatten(( - get_values_shallow(c.choice_map), - ((k, ChoiceMapNestedView(v)) - for (k, v) in get_submaps_shallow(c.choice_map)))) - r = Base.iterate(inner_iterator) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -function Base.iterate(c::ChoiceMapNestedView, state) - (inner_iterator, inner_state) = state - r = Base.iterate(inner_iterator, inner_state) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -# TODO: Allow different implementations of this method depending on the -# concrete type of the `ChoiceMap`, so that an already-existing data structure -# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it -# exists. -Base.keys(cv::Gen.ChoiceMapNestedView) = (k for (k, v) in cv) - -function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) - a.choice_map == b.choice_map -end - -# Length of a `ChoiceMapNestedView` is number of leaf values + number of -# submaps. Motivation: This matches what `length` would return for the -# equivalent nested dict. -function Base.length(cv::ChoiceMapNestedView) - +(get_values_shallow(cv.choice_map) |> collect |> length, - get_submaps_shallow(cv.choice_map) |> collect |> length) -end - -function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) - Base.show(io, MIME"text/plain"(), c.choice_map) -end - -nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) - -# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling -# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and -# aux data together. - -export nested_view - -""" - selected_choices = get_selected(choices::ChoiceMap, selection::Selection) - -Filter the choice map to include only choices in the given selection. - -Returns a new choice map. -""" -function get_selected( - choices::ChoiceMap, selection::Selection) - output = choicemap() - for (key, value) in get_values_shallow(choices) - if (key in selection) - output[key] = value - end - end - for (key, submap) in get_submaps_shallow(choices) - subselection = selection[key] - set_submap!(output, key, get_selected(submap, subselection)) - end - output -end - -export get_selected diff --git a/src/choice_map/array_interface.jl b/src/choice_map/array_interface.jl new file mode 100644 index 000000000..cf9d0bd03 --- /dev/null +++ b/src/choice_map/array_interface.jl @@ -0,0 +1,106 @@ +### interface for to_array and fill_array ### + +""" + arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} + +Populate an array with values of choices in the given assignment. + +It is an error if each of the values cannot be coerced into a value of the +given type. + +Implementation + +The default implmentation of `fill_array` will populate the array by sorting +the addresses of the choicemap using the `sort` function, then iterating over +each submap in this order and filling the array for that submap. + +To override the default implementation of `to_array`, +a concrete subtype `T <: ChoiceMap` should implement the following method: + + n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Populate `arr` with values from the given assignment, starting at `start_idx`, +and return the number of elements in `arr` that were populated. + +(This is for performance; it is more efficient to fill in values in a preallocated array +by implementing `_fill_array!` than to construct discontiguous arrays for each submap and then merge them.) +""" +function to_array(choices::ChoiceMap, ::Type{T}) where {T} + arr = Vector{T}(undef, 32) + n = _fill_array!(choices, arr, 1) + @assert n <= length(arr) + resize!(arr, n) + arr +end + +function _fill_array!(c::ValueChoiceMap{<:T}, arr::Vector{T}, start_idx::Int) where {T} + if length(arr) < start_idx + resize!(arr, 2 * start_idx) + end + arr[start_idx] = get_value(c) + 1 +end +function _fill_array!(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + value = get_value(c) + if length(arr) < start_idx + length(value) + resize!(arr, 2 * (start_idx + length(value))) + end + arr[start_idx:start_idx+length(value)-1] = value + length(value) +end + +# default _fill_array! implementation +function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + key_to_submap = collect(get_submaps_shallow(choices)) + sort!(key_to_submap, by = ((key, submap),) -> key) + idx = start_idx + for (key, submap) in key_to_submap + n_written = _fill_array!(submap, arr, idx) + idx += n_written + end + idx - start_idx +end + +""" + choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) + +Return an assignment with the same address structure as a prototype +assignment, but with values read off from the given array. + +It is an error if the number of choices in the prototype assignment +is not equal to the length the array. + +The order in which addresses are populated with values from the array +should match the order in which the array is populated with values +in a call to `to_array(proto_choices, T)`. By default, +this means sorting the top-level addresses for `proto_choices` +and then filling in the submaps depth-first in this order. + +# Implementation + +To support `from_array`, a concrete subtype `T <: ChoiceMap` must implement +the following method: + + (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Return an assignment with the same address structure as a prototype assignment, +but with values read off from `arr`, starting at position `start_idx`. Return the +number of elements read from `arr`. +""" +function from_array(proto_choices::ChoiceMap, arr::Vector) + (n, choices) = _from_array(proto_choices, arr, 1) + if n != length(arr) + error("Dimension mismatch: $n, $(length(arr))") + end + choices +end + +function _from_array(::ValueChoiceMap, arr::Vector, start_idx::Int) + (1, ValueChoiceMap(arr[start_idx])) +end +function _from_array(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + n_read = length(get_value(c)) + (n_read, ValueChoiceMap(arr[start_idx:start_idx+n_read-1])) +end + +export to_array, from_array \ No newline at end of file diff --git a/src/choice_map/choice_map.jl b/src/choice_map/choice_map.jl new file mode 100644 index 000000000..6e0f5bd7b --- /dev/null +++ b/src/choice_map/choice_map.jl @@ -0,0 +1,304 @@ +######################### +# choice map interface # +######################### + +""" + ChoiceMapGetValueError + +The error returned when a user attempts to call `get_value` +on an choicemap for an address which does not contain a value in that choicemap. +""" +struct ChoiceMapGetValueError <: Exception end +showerror(io::IO, ex::ChoiceMapGetValueError) = (print(io, "ChoiceMapGetValueError: no value was found for the `get_value` call.")) + +""" + abstract type ChoiceMap end + +Abstract type for maps from hierarchical addresses to values. +""" +abstract type ChoiceMap end + +""" + get_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for each top-level address associated with `choices`. +(This includes `ValueChoiceMap`s.) +""" +function get_submaps_shallow end + +""" + get_submap(choices::ChoiceMap, addr) + +Return the submap at the given address, or `EmptyChoiceMap` +if there is no submap at the given address. +""" +function get_submap end + +# provide _get_submap so when users overwrite get_submap(choices::CustomChoiceMap, addr::Pair) +# they can just call _get_submap for convenience if they want +@inline function _get_submap(choices::ChoiceMap, addr::Pair) + (first, rest) = addr + submap = get_submap(choices, first) + get_submap(submap, rest) +end + +""" + has_value(choices::ChoiceMap) + +Returns true if `choices` is a `ValueChoiceMap`. + + has_value(choices::ChoiceMap, addr) + +Returns true if `choices` has a value stored at address `addr`. +""" +function has_value end +@inline has_value(::ChoiceMap) = false +@inline has_value(c::ChoiceMap, addr) = has_value(get_submap(c, addr)) + +""" + get_value(choices::ChoiceMap) + +Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; +throws a `ChoiceMapGetValueError` if `choices` is not a `ValueChoiceMap`. + + get_value(choices::ChoiceMap, addr) +Returns the value stored in the submap with address `addr` or throws +a `ChoiceMapGetValueError` if no value exists at this address. + +A syntactic sugar is `Base.getindex`: + + value = choices[addr] +""" +function get_value end +@inline get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) +@inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) + +""" + has_submap(choices::ChoiceMap, addr) +Return true if there is a non-empty sub-assignment at the given address. +""" +function has_submap end +@inline has_submap(choices::ChoiceMap, addr) = !isempty(get_submap(choices, addr)) + +""" +schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} + +Return the (top-level) address schema for the given choice map. +""" +function get_address_schema end + +# get_values_shallow and get_nonvalue_submaps_shallow are just filters on get_submaps_shallow +""" + get_values_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, value)` +for each value stored at a top-level address in `choices`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) +""" +function get_values_shallow(choices::ChoiceMap) + ( + (addr, get_value(submap)) + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) + ) +end + +""" + get_nonvalue_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for every top-level submap stored in `choices` which is +not a `ValueChoiceMap`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) +""" +function get_nonvalue_submaps_shallow(choices::ChoiceMap) + (addr_to_submap for addr_to_submap in get_submaps_shallow(choices) if !has_value(addr_to_submap[2])) +end + +# a choicemap is empty if it has no submaps and no value +Base.isempty(c::ChoiceMap) = all(((addr, submap),) -> isempty(submap), get_submaps_shallow(c)) && !has_value(c) + +""" + EmptyChoiceMap + +A choicemap with no submaps or values. +""" +struct EmptyChoiceMap <: ChoiceMap end + +@inline has_value(::EmptyChoiceMap, addr...) = false +@inline get_value(::EmptyChoiceMap) = throw(ChoiceMapGetValueError()) +@inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() +@inline Base.isempty(::EmptyChoiceMap) = true +@inline get_submaps_shallow(::EmptyChoiceMap) = () +@inline get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() +@inline Base.:(==)(::EmptyChoiceMap, ::EmptyChoiceMap) = true + +""" + ValueChoiceMap + +A leaf-node choicemap. Stores a single value. +""" +struct ValueChoiceMap{T} <: ChoiceMap + val::T +end + +@inline has_value(choices::ValueChoiceMap) = true +@inline get_value(choices::ValueChoiceMap) = choices.val +@inline get_submap(choices::ValueChoiceMap, addr) = EmptyChoiceMap() +@inline get_submaps_shallow(choices::ValueChoiceMap) = () +@inline Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val +@inline Base.:(==)(a::ValueChoiceMap, b::ChoiceMap) = false +@inline Base.:(==)(a::ChoiceMap, b::ValueChoiceMap) = false +@inline Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) +@inline get_address_schema(::Type{<:ValueChoiceMap}) = AllAddressSchema() + +""" + choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) + +Merge two choice maps. + +It is an error if the choice maps both have values at the same address, or if +one choice map has a value at an address that is the prefix of the address of a +value in the other choice map. +""" +function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) + choices = DynamicChoiceMap() + for (key, submap) in get_submaps_shallow(choices1) + set_submap!(choices, key, merge(submap, get_submap(choices2, key))) + end + for (key, submap) in get_submaps_shallow(choices2) + if isempty(get_submap(choices1, key)) + set_submap!(choices, key, submap) + end + end + choices +end +Base.merge(c::ChoiceMap, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::ChoiceMap) = c +Base.merge(c::ValueChoiceMap, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::ValueChoiceMap) = c +Base.merge(::ValueChoiceMap, ::ChoiceMap) = error("ValueChoiceMaps cannot be merged") +Base.merge(::ChoiceMap, ::ValueChoiceMap) = error("ValueChoiceMaps cannot be merged") + +""" +Variadic merge of choice maps. +""" +function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) + reduce(Base.merge, choices_rest; init=choices1) +end + +function Base.:(==)(a::ChoiceMap, b::ChoiceMap) + for (addr, submap) in get_submaps_shallow(a) + if get_submap(b, addr) != submap + return false + end + end + for (addr, submap) in get_submaps_shallow(b) + if get_submap(a, addr) != submap + return false + end + end + return true +end + +# This is modeled after +# https://github.com/JuliaLang/julia/blob/7bff5cdd0fab8d625e48b3a9bb4e94286f2ba18c/base/abstractdict.jl#L530-L537 +const hasha_seed = UInt === UInt64 ? 0x6d35bb51952d5539 : 0x952d5539 +function Base.hash(a::ChoiceMap, h::UInt) + hv = hasha_seed + for (addr, value) in get_values_shallow(a) + hv = xor(hv, hash(addr, hash(value))) + end + for (addr, submap) in get_submaps_shallow(a) + hv = xor(hv, hash(addr, hash(submap))) + end + return hash(hv, h) +end + +function Base.isapprox(a::ChoiceMap, b::ChoiceMap) + for (addr, submap) in get_submaps_shallow(a) + if !isapprox(get_submap(b, addr), submap) + return false + end + end + return true +end + +""" + selected_choices = get_selected(choices::ChoiceMap, selection::Selection) + +Filter the choice map to include only choices in the given selection. + +Returns a new choice map. +""" +function get_selected( + choices::ChoiceMap, selection::Selection) + # TODO: return a `FilteringChoiceMap` which does this filtering lazily! + output = choicemap() + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) && addr in selection + output[addr] = get_value(submap) + else + subselection = selection[addr] + set_submap!(output, addr, get_selected(submap, subselection)) + end + end + output +end + +function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + indent_vert_str = join(indent_vert) + indent_vert_last_str = join(indent_vert_last) + indent_str = join(indent) + indent_last_str = join(indent_last) + key_and_values = collect(get_values_shallow(choices)) + key_and_submaps = collect(get_nonvalue_submaps_shallow(choices)) + n = length(key_and_values) + length(key_and_submaps) + cur = 1 + for (key, value) in key_and_values + # For strings, `print` is what we want; `Base.show` includes quote marks. + # https://docs.julialang.org/en/v1/base/io-network/#Base.print + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end + for (key, submap) in key_and_submaps + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") + _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) + _show_pretty(io, choices, 0, ()) +end + +export ChoiceMap, ValueChoiceMap, EmptyChoiceMap +export _get_submap, get_submap, get_submaps_shallow, has_submap +export get_value, has_value +export get_values_shallow, get_nonvalue_submaps_shallow +export get_address_schema, get_selected +export ChoiceMapGetValueError + +include("array_interface.jl") +include("dynamic_choice_map.jl") +include("static_choice_map.jl") +include("nested_view.jl") \ No newline at end of file diff --git a/src/choice_map/dynamic_choice_map.jl b/src/choice_map/dynamic_choice_map.jl new file mode 100644 index 000000000..112bfc92e --- /dev/null +++ b/src/choice_map/dynamic_choice_map.jl @@ -0,0 +1,153 @@ +####################### +# dynamic assignment # +####################### + +""" + struct DynamicChoiceMap <: ChoiceMap .. end + +A mutable map from arbitrary hierarchical addresses to values. + + choices = DynamicChoiceMap() + +Construct an empty map. + + choices = DynamicChoiceMap(tuples...) + +Construct a map containing each of the given (addr, value) tuples. +""" +struct DynamicChoiceMap <: ChoiceMap + submaps::Dict{Any, ChoiceMap} + function DynamicChoiceMap() + new(Dict()) + end +end + +function DynamicChoiceMap(tuples...) + choices = DynamicChoiceMap() + for tuple in tuples + if length(tuple) != 2 + error("Constructor accepts tuples of the form (address, value) only") + end + (addr, value) = tuple + choices[addr] = value + end + choices +end + +""" + choices = DynamicChoiceMap(other::ChoiceMap) + +Copy a choice map, returning a mutable choice map. +""" +function DynamicChoiceMap(other::ChoiceMap) + choices = DynamicChoiceMap() + for (addr, submap) in get_submaps_shallow(other) + if submap isa ValueChoiceMap + set_submap!(choices, addr, submap) + else + set_submap!(choices, addr, DynamicChoiceMap(submap)) + end + end + choices +end + +DynamicChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a DynamicChoiceMap") + +""" + choices = choicemap() + +Construct an empty mutable choice map. +""" +function choicemap() + DynamicChoiceMap() +end + +""" + choices = choicemap(tuples...) + +Construct a mutable choice map initialized with given address, value tuples. +""" +function choicemap(tuples...) + DynamicChoiceMap(tuples...) +end + +@inline get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps +@inline get_submap(choices::DynamicChoiceMap, addr) = get(choices.submaps, addr, EmptyChoiceMap()) +@inline get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) +@inline Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) + +# mutation (not part of the assignment interface) + +""" + set_value!(choices::DynamicChoiceMap, addr, value) + +Set the given value for the given address. + +Will cause any previous value or sub-assignment at this address to be deleted. +It is an error if there is already a value present at some prefix of the given address. + +The following syntactic sugar is provided: + + choices[addr] = value +""" +function set_value!(choices::DynamicChoiceMap, addr, value) + delete!(choices.submaps, addr) + choices.submaps[addr] = ValueChoiceMap(value) +end + +function set_value!(choices::DynamicChoiceMap, addr::Pair, value) + (first, rest) = addr + if !haskey(choices.submaps, first) + choices.submaps[first] = DynamicChoiceMap() + elseif has_value(choices.submaps[first]) + error("Tried to create assignment at $first but there was already a value there.") + end + set_value!(choices.submaps[first], rest, value) +end + +""" + set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) + +Replace the sub-assignment rooted at the given address with the given sub-assignment. +Set the given value for the given address. + +Will cause any previous value or sub-assignment at the given address to be deleted. +It is an error if there is already a value present at some prefix of address. +""" +function set_submap!(choices::DynamicChoiceMap, addr, new_node::ChoiceMap) + delete!(choices.submaps, addr) + if !isempty(new_node) + choices.submaps[addr] = new_node + end +end + +function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node::ChoiceMap) + (first, rest) = addr + if !haskey(choices.submaps, first) + choices.submaps[first] = DynamicChoiceMap() + elseif has_value(choices.submaps[first]) + error("Tried to create assignment at $first but there was already a value there.") + end + set_submap!(choices.submaps[first], rest, new_node) +end + +Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) + +function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + choices = DynamicChoiceMap() + keys_sorted = sort(collect(keys(proto_choices.submaps))) + idx = start_idx + for key in keys_sorted + (n_read, submap) = _from_array(proto_choices.submaps[key], arr, idx) + idx += n_read + choices.submaps[key] = submap + end + (idx - start_idx, choices) +end + +get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() + +export DynamicChoiceMap +export choicemap +export set_value! +export set_submap! \ No newline at end of file diff --git a/src/choice_map/nested_view.jl b/src/choice_map/nested_view.jl new file mode 100644 index 000000000..68add0a05 --- /dev/null +++ b/src/choice_map/nested_view.jl @@ -0,0 +1,80 @@ +############################################ +# Nested-dict–like accessor for choicemaps # +############################################ + +""" +Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than +the default syntax which looks like a flat dict of full keypaths. + +```jldoctest +julia> using Gen +julia> c = choicemap((:a, 1), + (:b => :c, 2)); +julia> cv = nested_view(c); +julia> c[:a] == cv[:a] +true +julia> c[:b => :c] == cv[:b][:c] +true +julia> length(cv) +2 +julia> length(cv[:b]) +1 +julia> sort(collect(keys(cv))) +[:a, :b] +julia> sort(collect(keys(cv[:b]))) +[:c] +``` +""" +struct ChoiceMapNestedView + choice_map::ChoiceMap +end + +ChoiceMapNestedView(cm::ValueChoiceMap) = get_value(cm) +ChoiceMapNestedView(::EmptyChoiceMap) = error("Can't convert an emptychoicemap to nested view.") + +function Base.getindex(choices::ChoiceMapNestedView, addr) + ChoiceMapNestedView(get_submap(choices.choice_map, addr)) +end + +function Base.iterate(c::ChoiceMapNestedView) + itr = ((k, ChoiceMapNestedView(s)) for (k, s) in get_submaps_shallow(c.choice_map)) + r = Base.iterate(itr) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +function Base.iterate(c::ChoiceMapNestedView, state) + (itr, st) = state + r = Base.iterate(itr, st) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +# TODO: Allow different implementations of this method depending on the +# concrete type of the `ChoiceMap`, so that an already-existing data structure +# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it +# exists. +Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) + +Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) = a.choice_map == b.choice_map + +function Base.length(cv::ChoiceMapNestedView) + length(collect(get_submaps_shallow(cv.choice_map))) +end +function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) + Base.show(io, MIME"text/plain"(), c.choice_map) +end + +nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) + +# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling +# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and +# aux data together. + +export nested_view \ No newline at end of file diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl new file mode 100644 index 000000000..587fc6ee5 --- /dev/null +++ b/src/choice_map/static_choice_map.jl @@ -0,0 +1,160 @@ +###################### +# static assignment # +###################### + +struct StaticChoiceMap{Addrs, SubmapTypes} <: ChoiceMap + submaps::NamedTuple{Addrs, SubmapTypes} + function StaticChoiceMap(submaps::NamedTuple{Addrs, SubmapTypes}) where {Addrs, SubmapTypes <: NTuple{n, ChoiceMap} where n} + new{Addrs, SubmapTypes}(submaps) + end +end + +function StaticChoiceMap(;addrs_to_vals_and_maps...) + addrs = Tuple(addr for (addr, val_or_map) in addrs_to_vals_and_maps) + maps = Tuple(val_or_map isa ChoiceMap ? val_or_map : ValueChoiceMap(val_or_map) for (addr, val_or_map) in addrs_to_vals_and_maps) + StaticChoiceMap(NamedTuple{addrs}(maps)) +end + +@inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) +@inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) + +# TODO: would it be faster to do static_get_submap? +function get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, addr::Symbol) where {Addrs, SubmapTypes} + if addr in Addrs + choices.submaps[addr] + else + EmptyChoiceMap() + end +end + +@generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} + if A in Addrs + quote choices.submaps[A] end + else + quote EmptyChoiceMap() end + end +end +@inline static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() + +@inline static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) +@inline static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) + +# convert a nonvalue choicemap all of whose top-level-addresses +# are symbols into a staticchoicemap at the top level +function StaticChoiceMap(other::ChoiceMap) + keys_and_nodes = get_submaps_shallow(other) + if length(keys_and_nodes) > 0 + addrs = Tuple(key for (key, _) in keys_and_nodes) + submaps = Tuple(submap for (_, submap) in keys_and_nodes) + else + addrs = () + submaps = () + end + StaticChoiceMap(NamedTuple{addrs}(submaps)) +end +StaticChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a StaticChoiceMap") +StaticChoiceMap(::NamedTuple{(),Tuple{}}) = EmptyChoiceMap() + +# TODO: deep conversion to static choicemap + +""" + choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) + +Return an assignment that contains `choices1` as a sub-assignment under `key1` +and `choices2` as a sub-assignment under `key2`. +""" +function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) + StaticChoiceMap(NamedTuple{(key1, key2)}((choices1, choices2))) +end + +""" + (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) + +Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. + +It is an error if there are any submaps at keys other than `key1` and `key2`. +""" +function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) + if length(collect(get_submaps_shallow(choices))) != 2 + error("Not a pair") + end + (get_submap(choices, key1), get_submap(choices, key2)) +end + +@generated function Base.merge(choices1::StaticChoiceMap{Addrs1, SubmapTypes1}, + choices2::StaticChoiceMap{Addrs2, SubmapTypes2}) where {Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} + + addr_to_type1 = Dict{Symbol, Type{<:ChoiceMap}}() + addr_to_type2 = Dict{Symbol, Type{<:ChoiceMap}}() + for (i, addr) in enumerate(Addrs1) + addr_to_type1[addr] = SubmapTypes1.parameters[i] + end + for (i, addr) in enumerate(Addrs2) + addr_to_type2[addr] = SubmapTypes2.parameters[i] + end + + merged_addrs = Tuple(union(Set(Addrs1), Set(Addrs2))) + submap_exprs = [] + + for addr in merged_addrs + type1 = get(addr_to_type1, addr, EmptyChoiceMap) + type2 = get(addr_to_type2, addr, EmptyChoiceMap) + if ((type1 <: ValueChoiceMap && type2 != EmptyChoiceMap) + || (type2 <: ValueChoiceMap && type1 != EmptyChoiceMap)) + error( "One choicemap has a value at address $addr; the other is nonempty at $addr. Cannot merge.") + end + if type1 <: EmptyChoiceMap + push!(submap_exprs, + quote choices2.submaps.$addr end + ) + elseif type2 <: EmptyChoiceMap + push!(submap_exprs, + quote choices1.submaps.$addr end + ) + else + push!(submap_exprs, + quote merge(choices1.submaps.$addr, choices2.submaps.$addr) end + ) + end + end + + quote + StaticChoiceMap(NamedTuple{$merged_addrs}(($(submap_exprs...),))) + end +end + +@generated function _from_array(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, + arr::Vector{T}, start_idx::Int) where {T, Addrs, SubmapTypes} + + perm = sortperm(collect(Addrs)) + sorted_addrs = Addrs[perm] + submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) + + exprs = [quote idx = start_idx end] + + for (idx, addr) in zip(perm, sorted_addrs) + submap_var_name = gensym(addr) + submap_var_names[idx] = submap_var_name + push!(exprs, + quote + (n_read, $submap_var_name) = _from_array(proto_choices.submaps.$addr, arr, idx) + idx += n_read + end + ) + end + + quote + $(exprs...) + submaps = NamedTuple{Addrs}(( $(submap_var_names...), )) + choices = StaticChoiceMap(submaps) + (idx - start_idx, choices) + end +end + +function get_address_schema(::Type{StaticChoiceMap{Addrs, SubmapTypes}}) where {Addrs, SubmapTypes} + StaticAddressSchema(Set(Addrs)) +end + +export StaticChoiceMap +export pair, unpair +export static_get_submap, static_get_value \ No newline at end of file diff --git a/src/distribution.jl b/src/distribution.jl new file mode 100644 index 000000000..fdb505512 --- /dev/null +++ b/src/distribution.jl @@ -0,0 +1,173 @@ +############################### +# Core Distribution Interface # +############################### + +struct DistributionTrace{T, Dist} <: Trace + val::T + args + score::Float64 + dist::Dist +end +@inline dist(tr::DistributionTrace{T, Dist}) where {T, Dist} = tr.dist + +abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end +DistributionTrace{T, Dist}(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...), dist) +@inline DistributionTrace(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...), dist) + +# we need to know the specific distribution in the trace type so the compiler can specialize GFI calls fully +@inline get_trace_type(::Dist) where {T, Dist <: Distribution{T}} = DistributionTrace{T, Dist} + +function Base.convert(::Type{<:DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} + DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, tr.score, tr.dist) +end + +""" + val::T = random(dist::Distribution{T}, args...) + +Sample a random choice from the given distribution with the given arguments. +""" +function random end + +""" + lpdf = logpdf(dist::Distribution{T}, value::T, args...) + +Evaluate the log probability (density) of the value. +""" +function logpdf end + +""" + has::Bool = has_output_grad(dist::Distribution) + +Return true if the distribution computes the gradient of the logpdf with respect to the value of the random choice. +""" +function has_output_grad end + +""" + grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...) + +Compute the gradient of the logpdf with respect to the value, and each of the arguments. + +If `has_output_grad` returns false, then the first element of the returned tuple is `nothing`. +Otherwise, the first element of the tuple is the gradient with respect to the value. +If the return value of `has_argument_grads` has a false value for at position `i`, then the `i+1`th element of the returned tuple has value `nothing`. +Otherwise, this element contains the gradient with respect to the `i`th argument. +""" +function logpdf_grad end + +is_discrete(::Distribution) = false # default + +# NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl + +get_return_type(::Distribution{T}) where {T} = T + + +############################## +# Distribution GFI Interface # +############################## + +@inline Base.getindex(trace::DistributionTrace) = trace.val +@inline Gen.get_args(trace::DistributionTrace) = trace.args +@inline Gen.get_choices(trace::DistributionTrace) = ValueChoiceMap(trace.val) # should be able to get type of val +@inline Gen.get_retval(trace::DistributionTrace) = trace.val +@inline Gen.get_gen_fn(trace::DistributionTrace) = dist(trace) +@inline Gen.get_score(trace::DistributionTrace) = trace.score +@inline Gen.project(trace::DistributionTrace, ::EmptySelection) = 0. +@inline Gen.project(trace::DistributionTrace, ::AllSelection) = get_score(trace) +@inline Gen.project(trace::DistributionTrace, c::ComplementSelection) = project_complement(trace, c.complement) +@inline project_complement(trace::DistributionTrace, ::EmptySelection) = get_score(trace) +@inline project_complement(trace::DistributionTrace, ::AllSelection) = 0. +@inline project_complement(trace::DistributionTrace, c::ComplementSelection) = Gen.project(trace, c.complement) +@inline function Gen.simulate(dist::Distribution, args::Tuple) + val = random(dist, args...) + DistributionTrace(val, args, dist) +end +@inline Gen.generate(dist::Distribution, args::Tuple, ::EmptyChoiceMap) = (simulate(dist, args), 0.) +@inline function Gen.generate(dist::Distribution, args::Tuple, constraints::ValueChoiceMap) + tr = DistributionTrace(get_value(constraints), args, dist) + weight = get_score(tr) + (tr, weight) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::ValueChoiceMap) + new_tr = DistributionTrace(get_value(constraints), args, dist(tr)) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, UnknownChange(), get_choices(tr)) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::EmptyChoiceMap) + new_tr = DistributionTrace(tr.val, args, dist(tr)) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, NoChange(), EmptyChoiceMap()) +end +@inline Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, constraints::EmptyChoiceMap) where {n} = (tr, 0., NoChange()) +# TODO: do I need an update method to handle empty choicemaps which are not `EmptyChoiceMap`s? +@inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, selection::EmptySelection) where {n} = (tr, 0., NoChange()) +@inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::EmptySelection) + new_tr = DistributionTrace(tr.val, args, dist(tr)) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, NoChange()) +end +@inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::AllSelection) + new_val = random(dist(tr), args...) + new_tr = DistributionTrace(new_val, args, dist(tr)) + (new_tr, 0., UnknownChange()) +end +@inline function Gen.propose(dist::Distribution, args::Tuple) + val = random(dist, args...) + score = logpdf(dist, val, args...) + (ValueChoiceMap(val), score, val) +end +@inline function Gen.assess(dist::Distribution, args::Tuple, choices::ValueChoiceMap) + weight = logpdf(dist, get_value(choices), args...) + (weight, choices.val) +end + + +# Gradient-based methods +@inline Gen.accepts_output_grad(dist::Distribution) = has_output_grad(dist) + +function Gen.choice_gradients(tr::DistributionTrace, ::AllSelection, retgrad) + if !has_output_grad(dist(tr)) + error("Distribution $(dist(tr)) does not compute gradient of logpdf with respect to value") + end + grads = logpdf_grad(dist(tr), tr.val, tr.args...) + output_grad = grads[1] + arg_grads = grads[2:end] + choice_values = ValueChoiceMap(tr.val) + choice_grads = ValueChoiceMap(isnothing(retgrad) ? output_grad : output_grad .+ retgrad) + return arg_grads, choice_values, choice_grads +end + +@inline function Gen.choice_gradients(tr::DistributionTrace, ::EmptySelection, retgrad) + arg_grads = logpdf_grad(dist(tr), tr.val, tr.args...)[2:end] + choice_values = EmptyChoiceMap() + choice_grads = EmptyChoiceMap() + return arg_grads, choice_values, choice_grads +end + +function Gen.choice_gradients(tr::DistributionTrace, c::ComplementSelection, retgrad) + if c.complement isa EmptySelection + return choice_gradients(tr, AllSelection(), retgrad) + elseif c.complement isa AllSelection + return choice_gradients(tr, EmptySelection(), retgrad) + elseif c.complement isa ComplementSelection + return choice_gradients(tr, c.complement.complement, retgrad) + else + error("Choice gradients not implemented for generic complement selection") + end +end + +@inline function Gen.accumulate_param_gradients!(tr::DistributionTrace, retgrad, scale_factor) + arg_grads = logpdf_grad(dist(tr), tr.val, tr.args...)[2:end] + return arg_grads +end + + +########### +# Exports # +########### + +export Distribution +export random +export logpdf +export logpdf_grad +export has_output_grad +export is_discrete diff --git a/src/dsl/dsl.jl b/src/dsl/dsl.jl index 77e794663..1d08c62a2 100644 --- a/src/dsl/dsl.jl +++ b/src/dsl/dsl.jl @@ -7,6 +7,7 @@ const DSL_ARG_GRAD_ANNOTATION = :grad const DSL_RET_GRAD_ANNOTATION = :grad const DSL_TRACK_DIFFS_ANNOTATION = :diffs const DSL_NO_JULIA_CACHE_ANNOTATION = :nojuliacache +const DSL_MACROS = Set([Symbol("@trace"), Symbol("@param")]) struct Argument name::Symbol @@ -81,6 +82,8 @@ include("dynamic.jl") include("static.jl") function desugar_tildes(expr) + trace_ref = GlobalRef(@__MODULE__, Symbol("@trace")) + line_num = LineNumberNode(1, :none) MacroTools.postwalk(expr) do e # Replace tilde statements with :gentrace expressions if MacroTools.@capture(e, {*} ~ rhs_call) diff --git a/src/dsl/static.jl b/src/dsl/static.jl index ba57c635e..807f65df0 100644 --- a/src/dsl/static.jl +++ b/src/dsl/static.jl @@ -47,10 +47,6 @@ end split_addr!(keys, addr_expr::QuoteNode) = push!(keys, addr_expr) split_addr!(keys, addr_expr::Symbol) = push!(keys, addr_expr) -"Construct choice-at or call-at combinator depending on type." -choice_or_call_at(gen_fn::GenerativeFunction, addr_typ) = call_at(gen_fn, addr_typ) -choice_or_call_at(dist::Distribution, addr_typ) = choice_at(dist, addr_typ) - "Generate informative node name for a Julia expression." gen_node_name(arg::Any) = gensym(string(arg)) gen_node_name(arg::Expr) = gensym(arg.head) @@ -74,12 +70,12 @@ function parse_trace_expr!(stmts, bindings, fn, args, addr, __module__) end addr = keys[1].value # Get top level address if length(keys) > 1 - # For each nesting level, wrap gen_fn_or_dist within choice_at / call_at + # For each nesting level, wrap gen_fn_or_dist within call_at for key in keys[2:end] push!(stmts, :($(esc(gen_fn_or_dist)) = - choice_or_call_at($(esc(gen_fn_or_dist)), Any))) + call_at($(esc(gen_fn_or_dist)), Any))) end - # Append the nested addresses as arguments to choice_at / call_at + # Append the nested addresses as arguments to call_at args = [args; reverse(keys[2:end])] end # Handle arguments to the traced call @@ -209,7 +205,7 @@ function parse_and_rewrite_trace!(stmts, bindings, expr, __module__) if MacroTools.@capture(expr, e_gentrace) # Parse "@trace(f(xs...), addr)" and return fresh variable call, addr = expr.args - if addr == nothing static_dsl_syntax_error(expr, "Address required.") end + if addr === nothing static_dsl_syntax_error(expr, "Address required.") end fn, args = call.args[1], call.args[2:end] parse_trace_expr!(stmts, bindings, fn, args, something(addr), __module__) else diff --git a/src/dynamic/assess.jl b/src/dynamic/assess.jl index c583d5079..0bf37a077 100644 --- a/src/dynamic/assess.jl +++ b/src/dynamic/assess.jl @@ -9,22 +9,6 @@ function GFAssessState(choices, params::Dict{Symbol,Any}) GFAssessState(choices, 0., AddressVisitor(), params) end -function traceat(state::GFAssessState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # get return value - retval = get_value(state.choices, key) - - # update weight - state.weight += logpdf(dist, retval, args...) - - retval -end - function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local retval::T diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index 9e02c0657..238c173d7 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -2,34 +2,6 @@ function maybe_track(arg, has_argument_grad::Bool, tape) has_argument_grad ? track(arg, tape) : arg end -@noinline function ReverseDiff.special_reverse_exec!( - instruction::ReverseDiff.SpecialInstruction{D}) where {D <: Distribution} - dist::D = instruction.func - args_maybe_tracked = instruction.input - score_tracked = instruction.output - arg_grads = logpdf_grad(dist, map(value, args_maybe_tracked)...) - value_tracked = args_maybe_tracked[1] - value_grad = arg_grads[1] - if istracked(value_tracked) - if has_output_grad(dist) - increment_deriv!(value_tracked, value_grad * deriv(score_tracked)) - else - error("Gradient required but not available for return value of distribution $dist") - end - end - for (i, (arg_maybe_tracked, grad, has_grad)) in enumerate( - zip(args_maybe_tracked[2:end], arg_grads[2:end], has_argument_grads(dist))) - if istracked(arg_maybe_tracked) - if has_grad - increment_deriv!(arg_maybe_tracked, grad * deriv(score_tracked)) - else - error("Gradient required but not available for argument $i of $dist") - end - end - end - nothing -end - ################### # accumulate_param_gradients! # @@ -70,19 +42,6 @@ function read_param(state::GFBackpropParamsState, name::Symbol) value end -function traceat(state::GFBackpropParamsState, dist::Distribution{T}, - args_maybe_tracked, key) where {T} - local retval::T - visit!(state.visitor, key) - retval = get_choice(state.trace, key).retval - args = map(value, args_maybe_tracked) - score_tracked = track(logpdf(dist, retval, args...), state.tape) - record!(state.tape, ReverseDiff.SpecialInstruction, dist, - (retval, args_maybe_tracked...,), score_tracked) - state.score += score_tracked - retval -end - struct BackpropParamsRecord gen_fn::GenerativeFunction subtrace::Any @@ -271,28 +230,6 @@ function fill_map!( fill_submaps!(map, tracked_trie, mode) end -function traceat(state::GFBackpropTraceState, dist::Distribution{T}, - args_maybe_tracked, key) where {T} - local retval::T - visit!(state.visitor, key) - retval = get_choice(state.trace, key).retval - args = map(value, args_maybe_tracked) - score_tracked = track(logpdf(dist, retval, args...), state.tape) - if key in state.selection - tracked_retval = track(retval, state.tape) - set_leaf_node!(state.tracked_choices, key, tracked_retval) - record!(state.tape, ReverseDiff.SpecialInstruction, dist, - (tracked_retval, args_maybe_tracked...,), score_tracked) - state.score += score_tracked - return tracked_retval - else - record!(state.tape, ReverseDiff.SpecialInstruction, dist, - (retval, args_maybe_tracked...,), score_tracked) - state.score += score_tracked - return retval - end -end - struct BackpropTraceRecord gen_fn::GenerativeFunction subtrace::Any diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index d83055444..ed3452aa4 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -126,48 +126,33 @@ function visit!(visitor::AddressVisitor, addr) push!(visitor.visited, addr) end +all_visited(::Selection, ::ValueChoiceMap) = false +all_visited(::AllSelection, ::ValueChoiceMap) = true function all_visited(visited::Selection, choices::ChoiceMap) - allvisited = true - for (key, _) in get_values_shallow(choices) - allvisited = allvisited && (key in visited) - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - allvisited = allvisited && all_visited(subvisited, submap) + if !all_visited(visited[key], submap) + return false end end - allvisited + return true end +get_unvisited(::Selection, v::ValueChoiceMap) = v +get_unvisited(::AllSelection, v::ValueChoiceMap) = EmptyChoiceMap() function get_unvisited(visited::Selection, choices::ChoiceMap) unvisited = choicemap() - for (key, _) in get_values_shallow(choices) - if !(key in visited) - set_value!(unvisited, key, get_value(choices, key)) - end - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - sub_unvisited = get_unvisited(subvisited, submap) - set_submap!(unvisited, key, sub_unvisited) - end + sub_unvisited = get_unvisited(visited[key], submap) + set_submap!(unvisited, key, sub_unvisited) end unvisited end get_visited(visitor) = visitor.visited -function check_no_submap(constraints::ChoiceMap, addr) +function check_is_empty(constraints::ChoiceMap, addr) if !isempty(get_submap(constraints, addr)) - error("Expected a value at address $addr but found a sub-assignment") - end -end - -function check_no_value(constraints::ChoiceMap, addr) - if has_value(constraints, addr) - error("Expected a sub-assignment at address $addr but found a value") + error("Expected a value or EmptyChoiceMap at address $addr but found a sub-assignment") end end diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index a89e0c352..b4ea5ac3d 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -11,38 +11,6 @@ function GFGenerateState(gen_fn, args, constraints, params) GFGenerateState(trace, constraints, 0., AddressVisitor(), params) end -function traceat(state::GFGenerateState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for constraints at this key - constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) - - # get return value - if constrained - retval = get_value(state.constraints, key) - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) - - # add to the trace - add_choice!(state.trace, key, retval, score) - - # increment weight - if constrained - state.weight += score - end - - retval -end - function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local subtrace::U diff --git a/src/dynamic/project.jl b/src/dynamic/project.jl index 358398e55..81b45b3c6 100644 --- a/src/dynamic/project.jl +++ b/src/dynamic/project.jl @@ -1,15 +1,9 @@ -function project_recurse(trie::Trie{Any,ChoiceOrCallRecord}, +function project_recurse(trie::Trie{Any, CallRecord}, selection::Selection) weight = 0. - for (key, choice_or_call) in get_leaf_nodes(trie) - if choice_or_call.is_choice - if key in selection - weight += choice_or_call.score - end - else - subselection = selection[key] - weight += project(choice_or_call.subtrace_or_retval, subselection) - end + for (key, call) in get_leaf_nodes(trie) + subselection = selection[key] + weight += project(call.subtrace, subselection) end for (key, subtrie) in get_internal_nodes(trie) subselection = selection[key] diff --git a/src/dynamic/propose.jl b/src/dynamic/propose.jl index 7c630cc19..32fc95da2 100644 --- a/src/dynamic/propose.jl +++ b/src/dynamic/propose.jl @@ -9,25 +9,6 @@ function GFProposeState(params::Dict{Symbol,Any}) GFProposeState(choicemap(), 0., AddressVisitor(), params) end -function traceat(state::GFProposeState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # sample return value - retval = random(dist, args...) - - # update assignment - set_value!(state.choices, key, retval) - - # update weight - state.weight += logpdf(dist, retval, args...) - - retval -end - function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local retval::T diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl index 13d14d86f..a4006a6c8 100644 --- a/src/dynamic/regenerate.jl +++ b/src/dynamic/regenerate.jl @@ -14,48 +14,6 @@ function GFRegenerateState(gen_fn, args, prev_trace, 0., visitor, params) end -function traceat(state::GFRegenerateState, dist::Distribution{T}, - args, key) where {T} - local prev_retval::T - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for previous choice at this key - has_previous = has_choice(state.prev_trace, key) - if has_previous - prev_choice = get_choice(state.prev_trace, key) - prev_retval = prev_choice.retval - prev_score = prev_choice.score - end - - # check whether the key was selected - in_selection = key in state.selection - - # get return value - if has_previous && in_selection - retval = random(dist, args...) - elseif has_previous - retval = prev_retval - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) - - # update weight - if has_previous && !in_selection - state.weight += score - prev_score - end - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local prev_retval::T @@ -101,13 +59,11 @@ function splice(state::GFRegenerateState, gen_fn::DynamicDSLFunction, retval end -function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function regenerate_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::EmptySelection) noise = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - if !choice_or_call.is_choice - noise += choice_or_call.noise - end + for (key, call) in get_leaf_nodes(prev_trie) + noise += call.noise end for (key, subtrie) in get_internal_nodes(prev_trie) noise += regenerate_delete_recurse(subtrie, EmptySelection()) @@ -115,12 +71,12 @@ function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, noise end -function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function regenerate_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::DynamicSelection) noise = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - if !(key in visited) && !choice_or_call.is_choice - noise += choice_or_call.noise + for (key, call) in get_leaf_nodes(prev_trie) + if !(key in visited) + noise += call.noise end end for (key, subtrie) in get_internal_nodes(prev_trie) diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index 7db1a213a..6fa0bc031 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -9,24 +9,6 @@ function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params) GFSimulateState(trace, AddressVisitor(), params) end -function traceat(state::GFSimulateState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - retval = random(dist, args...) - - # compute logpdf - score = logpdf(dist, retval, args...) - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local subtrace::U diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 8c02eceb5..0a169d737 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -1,96 +1,37 @@ -struct ChoiceRecord{T} - retval::T - score::Float64 -end - struct CallRecord{T} subtrace::T score::Float64 noise::Float64 end -struct ChoiceOrCallRecord{T} - subtrace_or_retval::T - score::Float64 - noise::Float64 # if choice then NaN - is_choice::Bool -end - -function ChoiceRecord(record::ChoiceOrCallRecord) - if !record.is_choice - error("Found call but expected choice") - end - ChoiceRecord(record.subtrace_or_retval, record.score) -end - -function CallRecord(record::ChoiceOrCallRecord) - if record.is_choice - error("Found choice but expected call") - end - CallRecord(record.subtrace_or_retval, record.score, record.noise) -end - mutable struct DynamicDSLTrace{T} <: Trace gen_fn::T - trie::Trie{Any,ChoiceOrCallRecord} - isempty::Bool + trie::Trie{Any,CallRecord} score::Float64 noise::Float64 args::Tuple retval::Any function DynamicDSLTrace{T}(gen_fn::T, args) where {T} - trie = Trie{Any,ChoiceOrCallRecord}() + trie = Trie{Any,CallRecord}() # retval is not known yet - new(gen_fn, trie, true, 0, 0, args) + new(gen_fn, trie, 0, 0, args) end end set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval) -function has_choice(trace::DynamicDSLTrace, addr) - haskey(trace.trie, addr) && trace.trie[addr].is_choice -end - -function has_call(trace::DynamicDSLTrace, addr) - haskey(trace.trie, addr) && !trace.trie[addr].is_choice -end - -function get_choice(trace::DynamicDSLTrace, addr) - choice = trace.trie[addr] - if !choice.is_choice - throw(KeyError(addr)) - end - ChoiceRecord(choice) -end - -function get_call(trace::DynamicDSLTrace, addr) - call = trace.trie[addr] - if call.is_choice - throw(KeyError(addr)) - end - CallRecord(call) -end - -function add_choice!(trace::DynamicDSLTrace, addr, retval, score) - if haskey(trace.trie, addr) - error("Value or subtrace already present at address $addr. - The same address cannot be reused for multiple random choices.") - end - trace.trie[addr] = ChoiceOrCallRecord(retval, score, NaN, true) - trace.score += score - trace.isempty = false -end +has_call(trace::DynamicDSLTrace, addr) = haskey(trace.trie, addr) +get_call(trace::DynamicDSLTrace, addr) = trace.trie[addr] function add_call!(trace::DynamicDSLTrace, addr, subtrace) if haskey(trace.trie, addr) - error("Value or subtrace already present at address $addr. + error("Subtrace already present at address $addr. The same address cannot be reused for multiple random choices.") end score = get_score(subtrace) noise = project(subtrace, EmptySelection()) submap = get_choices(subtrace) - trace.isempty = trace.isempty && isempty(submap) - trace.trie[addr] = ChoiceOrCallRecord(subtrace, score, noise, false) + trace.trie[addr] = CallRecord(subtrace, score, noise) trace.score += score trace.noise += noise end @@ -106,69 +47,28 @@ get_gen_fn(trace::DynamicDSLTrace) = trace.gen_fn ## get_choices ## -function get_choices(trace::DynamicDSLTrace) - if !trace.isempty - DynamicDSLChoiceMap(trace.trie) # see below - else - EmptyChoiceMap() - end -end +get_choices(trace::DynamicDSLTrace) = DynamicDSLChoiceMap(trace.trie) struct DynamicDSLChoiceMap <: ChoiceMap - trie::Trie{Any,ChoiceOrCallRecord} + trie::Trie{Any,CallRecord} end get_address_schema(::Type{DynamicDSLChoiceMap}) = DynamicAddressSchema() -Base.isempty(::DynamicDSLChoiceMap) = false # TODO not necessarily true -has_value(choices::DynamicDSLChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::DynamicDSLChoiceMap, addr::Pair) = _get_value(choices, addr) get_submap(choices::DynamicDSLChoiceMap, addr::Pair) = _get_submap(choices, addr) - function get_submap(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - if has_leaf_node(trie, addr) - # leaf node, must be a call - call = trie[addr] - if call.is_choice - throw(KeyError(addr)) - end - get_choices(call.subtrace_or_retval) - elseif has_internal_node(trie, addr) - # internal node - subtrie = get_internal_node(trie, addr) - DynamicDSLChoiceMap(subtrie) # see below + if haskey(choices.trie.leaf_nodes, addr) + get_choices(choices.trie[addr].subtrace) + elseif haskey(choices.trie.internal_nodes, addr) + DynamicDSLChoiceMap(choices.trie.internal_nodes[addr]) else EmptyChoiceMap() end end -function has_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - has_leaf_node(trie, addr) && trie[addr].is_choice -end - -function get_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - choice = trie[addr] - if !choice.is_choice - throw(KeyError(addr)) - end - choice.subtrace_or_retval -end - -function get_values_shallow(choices::DynamicDSLChoiceMap) - ((key, choice.subtrace_or_retval) - for (key, choice) in get_leaf_nodes(choices.trie) - if choice.is_choice) -end - function get_submaps_shallow(choices::DynamicDSLChoiceMap) - calls_iter = ((key, get_choices(call.subtrace_or_retval)) - for (key, call) in get_leaf_nodes(choices.trie) - if !call.is_choice) - internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) - for (key, trie) in get_internal_nodes(choices.trie)) - Iterators.flatten((calls_iter, internal_nodes_iter)) + leafs = ((key, get_choices(record.subtrace)) for (key, record) in get_leaf_nodes(choices.trie)) + internals = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) + Iterators.flatten((leafs, internals)) end ## Base.getindex ## @@ -176,13 +76,7 @@ end function _getindex(trace::DynamicDSLTrace, trie::Trie, addr::Pair) (first, rest) = addr if haskey(trie.leaf_nodes, first) - choice_or_call = trie.leaf_nodes[first] - if choice_or_call.is_choice - error("Unknown address $addr; random choice at $first") - else - subtrace = choice_or_call.subtrace_or_retval - return subtrace[rest] - end + return trie.leaf_nodes[first].subtrace[rest] elseif haskey(trie.internal_nodes, first) return _getindex(trace, trie.internal_nodes[first], rest) else @@ -192,14 +86,7 @@ end function _getindex(trace::DynamicDSLTrace, trie::Trie, addr) if haskey(trie.leaf_nodes, addr) - choice_or_call = trie.leaf_nodes[addr] - if choice_or_call.is_choice - # the value of the random choice - return choice_or_call.subtrace_or_retval - else - # the return value of the generative function call - return get_retval(choice_or_call.subtrace_or_retval) - end + return get_retval(trie.leaf_nodes[addr].subtrace) else error("No random choice or generative function call at address $addr") end diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 3e3605f59..913801a56 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -16,57 +16,6 @@ function GFUpdateState(gen_fn, args, prev_trace, constraints, params) 0., visitor, params, discard) end -function traceat(state::GFUpdateState, dist::Distribution{T}, - args::Tuple, key) where {T} - - local prev_retval::T - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for previous choice at this key - has_previous = has_choice(state.prev_trace, key) - if has_previous - prev_choice = get_choice(state.prev_trace, key) - prev_retval = prev_choice.retval - prev_score = prev_choice.score - end - - # check for constraints at this key - constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) - - # record the previous value as discarded if it is replaced - if constrained && has_previous - set_value!(state.discard, key, prev_retval) - end - - # get return value - if constrained - retval = get_value(state.constraints, key) - elseif has_previous - retval = prev_retval - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) - - # update the weight - if has_previous - state.weight += score - prev_score - elseif constrained - state.weight += score - end - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, args::Tuple, key) where {T,U} @@ -78,7 +27,6 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, visit!(state.visitor, key) # check for constraints at this key - check_no_value(state.constraints, key) constraints = get_submap(state.constraints, key) # get subtrace @@ -119,11 +67,11 @@ function splice(state::GFUpdateState, gen_fn::DynamicDSLFunction, retval end -function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function update_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::EmptySelection) score = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - score += choice_or_call.score + for (key, call) in get_leaf_nodes(prev_trie) + score += call.score end for (key, subtrie) in get_internal_nodes(prev_trie) score += update_delete_recurse(subtrie, EmptySelection()) @@ -131,12 +79,12 @@ function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, score end -function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function update_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::DynamicSelection) score = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) + for (key, call) in get_leaf_nodes(prev_trie) if !(key in visited) - score += choice_or_call.score + score += call.score end end for (key, subtrie) in get_internal_nodes(prev_trie) @@ -146,40 +94,25 @@ function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, score end -function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, - visited::AllSelection) - 0. -end - function add_unvisited_to_discard!(discard::DynamicChoiceMap, visited::DynamicSelection, prev_choices::ChoiceMap) - for (key, value) in get_values_shallow(prev_choices) + for (key, submap) in get_submaps_shallow(prev_choices) + # if key IS in visited, + # the recursive call to update already handled the discard + # for this entire submap; else we need to handle it if !(key in visited) - @assert !has_value(discard, key) @assert isempty(get_submap(discard, key)) - set_value!(discard, key, value) - end - end - for (key, submap) in get_submaps_shallow(prev_choices) - @assert !has_value(discard, key) - if key in visited - # the recursive call to update already handled the discard - # for this entire submap - continue - else subvisited = visited[key] if isempty(subvisited) # none of this submap was visited, so we discard the whole thing - @assert isempty(get_submap(discard, key)) set_submap!(discard, key, submap) else subdiscard = get_submap(discard, key) - add_unvisited_to_discard!( - isempty(subdiscard) ? choicemap() : subdiscard, - subvisited, submap) + subdiscard = isempty(subdiscard) ? choicemap() : subdiscard + add_unvisited_to_discard!(subdiscard, subvisited, submap) set_submap!(discard, key, subdiscard) - end + end end end end diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index 4cab00254..5a07e383f 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -1,12 +1,13 @@ using MacroTools: @capture, postwalk, unblock, rmlines, flatten function check_observations(choices::ChoiceMap, observations::ChoiceMap) - for (key, value) in get_values_shallow(observations) - !has_value(choices, key) && error("Check failed: observed choice at $key not found") - choices[key] != value && error("Check failed: value of observed choice at $key changed") - end for (key, submap) in get_submaps_shallow(observations) - check_observations(get_submap(choices, key), submap) + if has_value(submap) + !has_value(choices, key) && error("Check failed: observed choice at $key not found") + choices[key] != value && error("Check failed: value of observed choice at $key changed") + else + check_observations(get_submap(choices, key), submap) + end end end diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 0c5b997bc..cbd9ce400 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -14,10 +14,7 @@ function get_submap(choices::CallAtChoiceMap{K,T}, addr::K) where {K,T} end get_submap(choices::CallAtChoiceMap, addr::Pair) = _get_submap(choices, addr) -get_value(choices::CallAtChoiceMap, addr::Pair) = _get_value(choices, addr) -has_value(choices::CallAtChoiceMap, addr::Pair) = _has_value(choices, addr) get_submaps_shallow(choices::CallAtChoiceMap) = ((choices.key, choices.submap),) -get_values_shallow(::CallAtChoiceMap) = () # TODO optimize CallAtTrace using type parameters @@ -69,7 +66,7 @@ unpack_call_at_args(args) = (args[end], args[1:end-1]) function assess(gen_fn::CallAtCombinator, args::Tuple, choices::ChoiceMap) (key, kernel_args) = unpack_call_at_args(args) - if length(collect(get_submaps_shallow(choices))) > 1 || !isempty(get_values_shallow(choices)) + if length(collect(get_submaps_shallow(choices))) > 1 error("Not all constraints were consumed") end submap = get_submap(choices, key) @@ -149,12 +146,12 @@ function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad) input_grads = (kernel_input_grads..., nothing) value_choices = CallAtChoiceMap(trace.key, value_submap) gradient_choices = CallAtChoiceMap(trace.key, gradient_submap) - (input_grads, value_choices, gradient_choices) + return (input_grads, value_choices, gradient_choices) end function accumulate_param_gradients!(trace::CallAtTrace, retval_grad) kernel_input_grads = accumulate_param_gradients!(trace.subtrace, retval_grad) - (kernel_input_grads..., nothing) + return (kernel_input_grads..., nothing) end export call_at diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl deleted file mode 100644 index 09cb922fa..000000000 --- a/src/modeling_library/choice_at/choice_at.jl +++ /dev/null @@ -1,175 +0,0 @@ -# TODO optimize ChoiceAtTrace using type parameters - -struct ChoiceAtTrace <: Trace - gen_fn::GenerativeFunction # the ChoiceAtCombinator (not the kernel) - value::Any - key::Any - kernel_args::Tuple - score::Float64 -end - -get_args(trace::ChoiceAtTrace) = (trace.kernel_args..., trace.key) -get_retval(trace::ChoiceAtTrace) = trace.value -get_score(trace::ChoiceAtTrace) = trace.score -get_gen_fn(trace::ChoiceAtTrace) = trace.gen_fn - -struct ChoiceAtChoiceMap{T,K} <: ChoiceMap - key::K - value::T -end - -get_choices(trace::ChoiceAtTrace) = ChoiceAtChoiceMap(trace.key, trace.value) -Base.isempty(::ChoiceAtChoiceMap) = false -function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap} - SingleDynamicKeyAddressSchema() -end -get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr) -has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr) -function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K} - choices.key == addr ? choices.value : throw(KeyError(choices, addr)) -end -get_submaps_shallow(choices::ChoiceAtChoiceMap) = () -get_values_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, choices.value),) - -struct ChoiceAtCombinator{T,K} <: GenerativeFunction{T, ChoiceAtTrace} - dist::Distribution{T} -end - -accepts_output_grad(gen_fn::ChoiceAtCombinator) = has_output_grad(gen_fn.dist) - -# TODO -# accepts_output_grad is true if the return value is dependent on the 'gradient source elements' -# if the random choice itself is not a 'gradient source element' then it is independent (false) -# if the random choice is a 'gradient source element', then the return value is dependent (true) -# we will consider the random choice as a gradient source element if the -# distribution has has_output_grad = true) - -function choice_at(dist::Distribution{T}, ::Type{K}) where {T,K} - ChoiceAtCombinator{T,K}(dist) -end - -unpack_choice_at_args(args) = (args[end], args[1:end-1]) - -function assess(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - value = get_value(choices, key) - weight = logpdf(gen_fn.dist, value, kernel_args...) - (weight, value) -end - -function propose(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - value = random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - choices = ChoiceAtChoiceMap(key, value) - (choices, score, value) -end - -function simulate(gen_fn::ChoiceAtCombinator, args::Tuple) - (key, kernel_args) = unpack_choice_at_args(args) - value = random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - ChoiceAtTrace(gen_fn, value, key, kernel_args, score) -end - -function generate(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - constrained = has_value(choices, key) - value = constrained ? get_value(choices, key) : random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - trace = ChoiceAtTrace(gen_fn, value, key, kernel_args, score) - weight = constrained ? score : 0. - (trace, weight) -end - -function project(trace::ChoiceAtTrace, selection::Selection) - (trace.key in selection) ? trace.score : 0. -end - -function update(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, - choices::ChoiceMap) - (key, kernel_args) = unpack_choice_at_args(args) - key_changed = (key != trace.key) - constrained = has_value(choices, key) - if key_changed && constrained - new_value = get_value(choices, key) - discard = ChoiceAtChoiceMap(trace.key, trace.value) - elseif !key_changed && constrained - new_value = get_value(choices, key) - discard = ChoiceAtChoiceMap(key, trace.value) - elseif !key_changed && !constrained - new_value = trace.value - discard = EmptyChoiceMap() - else - error("New address $key not constrained in update") - end - new_score = logpdf(trace.gen_fn.dist, new_value, kernel_args...) - new_trace = ChoiceAtTrace(trace.gen_fn, new_value, key, kernel_args, new_score) - weight = new_score - trace.score - (new_trace, weight, UnknownChange(), discard) -end - -function regenerate(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, - selection::Selection) - (key, kernel_args) = unpack_choice_at_args(args) - key_changed = (key != trace.key) - selected = key in selection - if !key_changed && selected - new_value = random(trace.gen_fn.dist, kernel_args...) - elseif !key_changed && !selected - new_value = trace.value - elseif key_changed && !selected - new_value = random(trace.gen_fn.dist, kernel_args...) - else - error("Cannot select new address $key in regenerate") - end - new_score = logpdf(trace.gen_fn.dist, new_value, kernel_args...) - if !key_changed && selected - weight = 0. - elseif !key_changed && !selected - weight = new_score - trace.score - elseif key_changed && !selected - weight = 0. - end - new_trace = ChoiceAtTrace(trace.gen_fn, new_value, key, kernel_args, new_score) - (new_trace, weight, UnknownChange()) -end - -function choice_gradients(trace::ChoiceAtTrace, selection::Selection, retval_grad) - if retval_grad != nothing && !has_output_grad(trace.gen_fn.dist) - error("return value gradient not accepted but one was provided") - end - kernel_arg_grads = logpdf_grad(trace.gen_fn.dist, trace.value, trace.kernel_args...) - if trace.key in selection - value_choices = ChoiceAtChoiceMap(trace.key, trace.value) - choice_grad = kernel_arg_grads[1] - if choice_grad == nothing - error("gradient not available for selected choice") - end - if retval_grad != nothing - choice_grad += retval_grad - end - gradient_choices = ChoiceAtChoiceMap(trace.key, choice_grad) - else - value_choices = EmptyChoiceMap() - gradient_choices = EmptyChoiceMap() - end - input_grads = (kernel_arg_grads[2:end]..., nothing) - (input_grads, value_choices, gradient_choices) -end - -function accumulate_param_gradients!(trace::ChoiceAtTrace, retval_grad) - if retval_grad != nothing && !has_output_grad(trace.gen_fn.dist) - error("return value gradient not accepted but one was provided") - end - kernel_arg_grads = logpdf_grad(trace.gen_fn.dist, trace.value, trace.kernel_args...) - (kernel_arg_grads[2:end]..., nothing) -end - -export choice_at diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl index 13d6e4880..8b280ea68 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -5,54 +5,6 @@ import Distributions using SpecialFunctions: loggamma, logbeta, digamma -abstract type Distribution{T} end - -""" - val::T = random(dist::Distribution{T}, args...) - -Sample a random choice from the given distribution with the given arguments. -""" -function random end - -""" - lpdf = logpdf(dist::Distribution{T}, value::T, args...) - -Evaluate the log probability (density) of the value. -""" -function logpdf end - -""" - has::Bool = has_output_grad(dist::Distribution) - -Return true of the gradient if the distribution computes the gradient of the logpdf with respect to the value of the random choice. -""" -function has_output_grad end - -""" - grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...) - -Compute the gradient of the logpdf with respect to the value, and each of the arguments. - -If `has_output_grad` returns false, then the first element of the returned tuple is `nothing`. -Otherwise, the first element of the tuple is the gradient with respect to the value. -If the return value of `has_argument_grads` has a false value for at position `i`, then the `i+1`th element of the returned tuple has value `nothing`. -Otherwise, this element contains the gradient with respect to the `i`th argument. -""" -function logpdf_grad end - -is_discrete(::Distribution) = false # default - -# NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl - -get_return_type(::Distribution{T}) where {T} = T - -export Distribution -export random -export logpdf -export logpdf_grad -export has_output_grad -export is_discrete - # built-in distributions include("distributions/distributions.jl") @@ -70,7 +22,6 @@ include("mixture.jl") include("vector.jl") # built-in generative function combinators -include("choice_at/choice_at.jl") include("call_at/call_at.jl") include("map/map.jl") include("unfold/unfold.jl") diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 7b4d1d23d..57fd6ae38 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -84,17 +84,7 @@ function get_submap(choices::RecurseTraceChoiceMap, end end -function get_submap(choices::RecurseTraceChoiceMap, addr::Pair) - _get_submap(choices, addr) -end - -function has_value(choices::RecurseTraceChoiceMap, addr::Pair) - _has_value(choices, addr) -end - -function get_value(choices::RecurseTraceChoiceMap, addr::Pair) - _get_value(choices, addr) -end +get_submap(choices::RecurseTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) get_values_shallow(choices::RecurseTraceChoiceMap) = () @@ -333,6 +323,9 @@ function recurse_unpack_constraints(constraints::ChoiceMap) production_constraints = Dict{Int, Any}() aggregation_constraints = Dict{Int, Any}() for (addr, node) in get_submaps_shallow(constraints) + if has_value(node) + error("Unknown address: $(addr)") + end idx::Int = addr[1] if addr[2] == Val(:production) production_constraints[idx] = node @@ -342,9 +335,6 @@ function recurse_unpack_constraints(constraints::ChoiceMap) error("Unknown address: $addr") end end - if !isempty(get_values_shallow(constraints)) - error("Unknown address: $(first(get_values_shallow(constraints))[1])") - end return (production_constraints, aggregation_constraints) end diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index 02923a776..853fd3aa7 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -19,14 +19,14 @@ function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap) # Add (address, value) to new_choices from prev_choices if address does not occur in choices. for (address, value) in prev_choice_value_iterator - address in keys(choice_value_iterator) && continue + address in map(first, collect(choice_value_iterator)) && continue set_value!(new_choices, address, value) end # Add (address, submap) to new_choices from prev_choices if address does not occur in choices. # If it does, enter a recursive call to update_recurse_merge. for (address, node1) in prev_choice_submap_iterator - if address in keys(choice_submap_iterator) + if address in map(first, collect(choice_submap_iterator)) node2 = get_submap(choices, address) node = update_recurse_merge(node1, node2) set_submap!(new_choices, address, node) diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index f35360b50..e5daa8688 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -92,10 +92,6 @@ end end @inline get_submap(choices::VectorTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) -@inline get_value(choices::VectorTraceChoiceMap, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::VectorTraceChoiceMap, addr::Pair) = _has_value(choices, addr) -@inline get_values_shallow(::VectorTraceChoiceMap) = () - ############################################ # code shared by vector-shaped combinators # diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 8549be4e3..1107380b9 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -19,30 +19,24 @@ maybe_tracked_value_var(node::JuliaNode) = Symbol("$(maybe_tracked_value_prefix) const maybe_tracked_arg_prefix = gensym("maybe_tracked_arg") maybe_tracked_arg_var(node::JuliaNode, i::Int) = Symbol("$(maybe_tracked_arg_prefix)_$(node.name)_$i") -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode) +function fwd_pass!(selected_calls, fwd_marked, node::TrainableParameterNode) # TODO: only need to mark it if we are doing backprop params push!(fwd_marked, node) end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode) +function fwd_pass!(selected_calls, fwd_marked, node::ArgumentNode) if node.compute_grad push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode) +function fwd_pass!(selected_calls, fwd_marked, node::JuliaNode) if any(input_node in fwd_marked for input_node in node.inputs) push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode) - if node in selected_choices - push!(fwd_marked, node) - end -end - -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode) +function fwd_pass!(selected_calls, fwd_marked, node::GenerativeFunctionCallNode) if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) push!(fwd_marked, node) end @@ -60,15 +54,6 @@ function back_pass!(back_marked, node::JuliaNode) end end -function back_pass!(back_marked, node::RandomChoiceNode) - # the logpdf of every random choice is a SINK - for input_node in node.inputs - push!(back_marked, input_node) - end - # the value of every random choice is in back_marked, since it affects its logpdf - push!(back_marked, node) -end - function back_pass!(back_marked, node::GenerativeFunctionCallNode) # the logpdf of every generative function call is a SINK for input_node in node.inputs @@ -134,24 +119,6 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) - # for reference by other nodes during back_codegen! - # could performance optimize this away - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - - # every random choice is in back_marked, since it affects it logpdf, but - # also possibly due to other downstream usage of the value - @assert node in back_marked - - if node in fwd_marked - # the only way we are fwd_marked is if this choice was selected - - # initialize gradient with respect to the value of the random choice to zero - # it will be a runtime error, thrown here, if there is no zero() method - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) - end -end - function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) # for reference by other nodes during back_codegen! # could performance optimize this away @@ -217,19 +184,19 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: end function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, - node::RandomChoiceNode, logpdf_grad::Symbol) + node::GenerativeFunctionCallNode, logpdf_grad::Symbol) # only evaluate the gradient of the logpdf if we need to if any(input_node in fwd_marked for input_node in node.inputs) || node in fwd_marked args = map((input_node) -> input_node.name, node.inputs) - push!(stmts, :($logpdf_grad = logpdf_grad($(node.dist), $(node.name), $(args...)))) + push!(stmts, :($logpdf_grad = logpdf_grad($(node.generative_function), $(node.name), $(args...)))) end # increment gradients of input nodes that are in fwd_marked for (i, input_node) in enumerate(node.inputs) if input_node in fwd_marked @assert input_node in back_marked # this ensured its gradient will have been initialized - if !has_argument_grads(node.dist)[i] - error("Distribution $(node.dist) does not have logpdf gradient for argument $i") + if !has_argument_grads(node.generative_function)[i] + error("Distribution $(node.generative_function) does not have logpdf gradient for argument $i") end push!(stmts, :($(gradient_var(input_node)) += $logpdf_grad[$(QuoteNode(i+1))])) end @@ -242,31 +209,8 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropTraceMode) - logpdf_grad = gensym("logpdf_grad") - - # backpropagate to the inputs - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) - - # backpropagate to the value (if it was selected) - if node in fwd_marked - if !has_output_grad(node.dist) - error("Distribution $dist does not logpdf gradient for its output value") - end - push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) - end -end - -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropParamsMode) - logpdf_grad = gensym("logpdf_grad") - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) -end - function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, mode::BackpropTraceMode) - # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked @@ -325,46 +269,18 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, end end -function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, - selected_calls::Set{GenerativeFunctionCallNode}, +function generate_value_gradient_trie(selected_calls::Set{GenerativeFunctionCallNode}, value_trie::Symbol, gradient_trie::Symbol) - selected_choices_vec = collect(selected_choices) - quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) - leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec) - leaf_gradients = map((node) -> gradient_var(node), selected_choices_vec) - selected_calls_vec = collect(selected_calls) quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec) - internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), + internal_value_choicemaps = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), selected_calls_vec) - internal_gradients = map((node) -> gradient_trie_var(node), selected_calls_vec) - quote - $value_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),))) - $gradient_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),))) - end -end - -function get_selected_choices(::EmptyAddressSchema, ::StaticIR) - Set{RandomChoiceNode}() -end - -function get_selected_choices(::AllAddressSchema, ir::StaticIR) - Set{RandomChoiceNode}(ir.choice_nodes) -end + internal_gradient_choicemaps = map((node) -> gradient_trie_var(node), selected_calls_vec) -function get_selected_choices(schema::StaticAddressSchema, ir::StaticIR) - selected_choice_addrs = Set(keys(schema)) - selected_choices = Set{RandomChoiceNode}() - for node in ir.choice_nodes - if node.addr in selected_choice_addrs - push!(selected_choices, node) - end + quote + $value_trie = StaticChoiceMap(NamedTuple{($(quoted_internal_keys...),)}(($(internal_value_choicemaps...),))) + $gradient_trie = StaticChoiceMap(NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradient_choicemaps...),))) end - selected_choices end function get_selected_calls(::EmptyAddressSchema, ::StaticIR) @@ -372,7 +288,7 @@ function get_selected_calls(::EmptyAddressSchema, ::StaticIR) end function get_selected_calls(::AllAddressSchema, ir::StaticIR) - Set{GenerativeFunctionCallNode}(ir.call_nodes) + Set{GenerativeFunctionCallNode}(ir.call_nodes...) end function get_selected_calls(schema::StaticAddressSchema, ir::StaticIR) @@ -397,13 +313,12 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, end ir = get_ir(gen_fn_type) - selected_choices = get_selected_choices(schema, ir) selected_calls = get_selected_calls(schema, ir) # forward marking pass fwd_marked = Set{StaticIRNode}() for node in ir.nodes - fwd_pass!(selected_choices, selected_calls, fwd_marked, node) + fwd_pass!(selected_calls, fwd_marked, node) end # backward marking pass @@ -432,7 +347,7 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # assemble value_trie and gradient_trie value_trie = gensym("value_trie") gradient_trie = gensym("gradient_trie") - push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls, + push!(stmts, generate_value_gradient_trie(selected_calls, value_trie, gradient_trie)) # gradients with respect to inputs @@ -451,7 +366,7 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, ir = get_ir(gen_fn_type) # unlike choice_gradients we don't take gradients w.r.t. the value of random choices - selected_choices = Set{RandomChoiceNode}() + # selected_choices = Set{GenerativeFunctionCallNode}() # we need to guarantee that we visit every generative function call, # because we need to backpropagate to its trainable parameters @@ -461,7 +376,7 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, # forward marking pass fwd_marked = Set{StaticIRNode}() for node in ir.nodes - fwd_pass!(selected_choices, selected_calls, fwd_marked, node) + fwd_pass!(selected_calls, fwd_marked, node) end # backward marking pass diff --git a/src/static_ir/dag.jl b/src/static_ir/dag.jl index c82658892..de6acf6e7 100644 --- a/src/static_ir/dag.jl +++ b/src/static_ir/dag.jl @@ -18,14 +18,6 @@ struct JuliaNode <: StaticIRNode typ::Union{Symbol,Expr,QuoteNode} end -struct RandomChoiceNode <: StaticIRNode - dist::Distribution - inputs::Vector{StaticIRNode} - addr::Symbol - name::Symbol - typ::Union{Symbol,Expr,QuoteNode} -end - struct GenerativeFunctionCallNode <: StaticIRNode generative_function::GenerativeFunction inputs::Vector{StaticIRNode} @@ -38,7 +30,6 @@ struct StaticIR nodes::Vector{StaticIRNode} trainable_param_nodes::Vector{TrainableParameterNode} arg_nodes::Vector{ArgumentNode} - choice_nodes::Vector{RandomChoiceNode} call_nodes::Vector{GenerativeFunctionCallNode} julia_nodes::Vector{JuliaNode} return_node::StaticIRNode @@ -50,12 +41,10 @@ mutable struct StaticIRBuilder node_set::Set{StaticIRNode} trainable_param_nodes::Vector{TrainableParameterNode} arg_nodes::Vector{ArgumentNode} - choice_nodes::Vector{RandomChoiceNode} call_nodes::Vector{GenerativeFunctionCallNode} julia_nodes::Vector{JuliaNode} return_node::Union{Nothing,StaticIRNode} vars::Set{Symbol} - addrs_to_choice_nodes::Dict{Symbol,RandomChoiceNode} addrs_to_call_nodes::Dict{Symbol,GenerativeFunctionCallNode} accepts_output_grad::Bool end @@ -65,17 +54,15 @@ function StaticIRBuilder() node_set = Set{StaticIRNode}() trainable_param_nodes = Vector{TrainableParameterNode}() arg_nodes = Vector{ArgumentNode}() - choice_nodes = Vector{RandomChoiceNode}() call_nodes = Vector{GenerativeFunctionCallNode}() julia_nodes = Vector{JuliaNode}() return_node = nothing vars = Set{Symbol}() - addrs_to_choice_nodes = Dict{Symbol,RandomChoiceNode}() addrs_to_call_nodes = Dict{Symbol,GenerativeFunctionCallNode}() accepts_output_grad = false - StaticIRBuilder(nodes, node_set, trainable_param_nodes, arg_nodes, choice_nodes, call_nodes, + StaticIRBuilder(nodes, node_set, trainable_param_nodes, arg_nodes, call_nodes, julia_nodes, - return_node, vars, addrs_to_choice_nodes, addrs_to_call_nodes, + return_node, vars, addrs_to_call_nodes, accepts_output_grad) end @@ -87,7 +74,6 @@ function build_ir(builder::StaticIRBuilder) builder.nodes, builder.trainable_param_nodes, builder.arg_nodes, - builder.choice_nodes, builder.call_nodes, builder.julia_nodes, builder.return_node, @@ -109,7 +95,7 @@ function check_inputs_exist(builder::StaticIRBuilder, input_nodes) end function check_addr_unique(builder::StaticIRBuilder, addr::Symbol) - if haskey(builder.addrs_to_choice_nodes, addr) || haskey(builder.addrs_to_call_nodes, addr) + if haskey(builder.addrs_to_call_nodes, addr) error("Address $addr was not unique") end end @@ -164,20 +150,6 @@ function add_constant_node!(builder::StaticIRBuilder, val, node end -function add_addr_node!(builder::StaticIRBuilder, dist::Distribution; - inputs::Vector=[], addr::Symbol=gensym(), - name::Symbol=gensym()) - check_unique_var(builder, name) - check_addr_unique(builder, addr) - check_inputs_exist(builder, inputs) - typ = QuoteNode(get_return_type(dist)) - node = RandomChoiceNode(dist, inputs, addr, name, typ) - _add_node!(builder, node) - builder.addrs_to_choice_nodes[addr] = node - push!(builder.choice_nodes, node) - node -end - function add_addr_node!(builder::StaticIRBuilder, gen_fn::GenerativeFunction; inputs::Vector=[], addr::Symbol=gensym(), name::Symbol=gensym()) diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl index 643686766..b6b64253f 100644 --- a/src/static_ir/generate.jl +++ b/src/static_ir/generate.jl @@ -21,27 +21,6 @@ function process!(state::StaticIRGenerateState, node::JuliaNode, options) end end -function process!(state::StaticIRGenerateState, node::RandomChoiceNode, options) - schema = state.schema - args = map((input_node) -> input_node.name, node.inputs) - incr = gensym("logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) - if isa(schema, StaticAddressSchema) && (node.addr in keys(schema)) - push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :static_get_value))(constraints, Val($addr)))) - push!(state.stmts, :($incr = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(args...)))) - push!(state.stmts, :($weight += $incr)) - else - push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(args...)))) - push!(state.stmts, :($incr = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(args...)))) - end - push!(state.stmts, :($(get_value_fieldname(node)) = $(node.name))) - push!(state.stmts, :($(get_score_fieldname(node)) = $incr)) - push!(state.stmts, :($num_nonempty_fieldname += 1)) - push!(state.stmts, :($total_score_fieldname += $incr)) -end - function process!(state::StaticIRGenerateState, node::GenerativeFunctionCallNode, options) schema = state.schema args = map((input_node) -> input_node.name, node.inputs) diff --git a/src/static_ir/print_ir.jl b/src/static_ir/print_ir.jl index 0a8c52b88..f6291a700 100644 --- a/src/static_ir/print_ir.jl +++ b/src/static_ir/print_ir.jl @@ -27,10 +27,5 @@ function print_ir(io::IO, node::GenerativeFunctionCallNode) print(io, "$(node.name) = @trace($(gen_fn_name)($inputs), :$(node.addr))") end -function print_ir(io::IO, node::RandomChoiceNode) - inputs = join((string(i.name) for i in node.inputs), ", ") - print(io, "$(node.name) = @trace($(node.dist)($inputs), :$(node.addr))") -end - ir_name(fn::GenerativeFunction) = nameof(typeof(fn)) ir_name(fn::DynamicDSLFunction) = nameof(fn.julia_function) diff --git a/src/static_ir/project.jl b/src/static_ir/project.jl index b336d9735..72024b651 100644 --- a/src/static_ir/project.jl +++ b/src/static_ir/project.jl @@ -5,14 +5,6 @@ end function process!(state::StaticIRProjectState, node) end -function process!(state::StaticIRProjectState, node::RandomChoiceNode) - schema = state.schema - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) - if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.stmts, :($weight += trace.$(get_score_fieldname(node)))) - end -end - function process!(state::StaticIRProjectState, node::GenerativeFunctionCallNode) schema = state.schema addr = QuoteNode(node.addr) diff --git a/src/static_ir/render_ir.jl b/src/static_ir/render_ir.jl index 22e7b3625..880fec505 100644 --- a/src/static_ir/render_ir.jl +++ b/src/static_ir/render_ir.jl @@ -1,7 +1,12 @@ label(node::ArgumentNode) = String(node.name) label(node::JuliaNode) = String(node.name) -label(node::RandomChoiceNode) = "$(node.dist) $(node.addr) $(node.name)" -label(node::GenerativeFunctionCallNode) = "$(node.addr) $(node.name)" +function label(node::GenerativeFunctionCallNode) + if node.generative_function isa Distribution + "$(node.generative_function) $(node.addr) $(node.name)" + else + "$(node.addr) $(node.name)" + end +end function draw_graph(ir::StaticIR, graphviz, fname) dot = graphviz.Digraph() @@ -14,7 +19,7 @@ function draw_graph(ir::StaticIR, graphviz, fname) shape = "diamond" color = "white" parents = [] - elseif isa(node, RandomChoiceNode) + elseif isa(node, GenerativeFunctionCallNode) && node.generative_function isa Distribution shape = "ellipse" color = "white" parents = node.inputs diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl index 02fb78800..188298525 100644 --- a/src/static_ir/simulate.jl +++ b/src/static_ir/simulate.jl @@ -20,19 +20,6 @@ function process!(state::StaticIRSimulateState, node::JuliaNode, options) end end -function process!(state::StaticIRSimulateState, node::RandomChoiceNode, options) - args = map((input_node) -> input_node.name, node.inputs) - incr = gensym("logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(args...)))) - push!(state.stmts, :($incr = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(args...)))) - push!(state.stmts, :($(get_value_fieldname(node)) = $(node.name))) - push!(state.stmts, :($(get_score_fieldname(node)) = $incr)) - push!(state.stmts, :($num_nonempty_fieldname += 1)) - push!(state.stmts, :($total_score_fieldname += $incr)) -end - function process!(state::StaticIRSimulateState, node::GenerativeFunctionCallNode, options) args = map((input_node) -> input_node.name, node.inputs) args_tuple = Expr(:tuple, args...) diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index 3e27810f1..e0eabe1bf 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -52,6 +52,7 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati params_grad::Dict{Symbol,Any} params::Dict{Symbol,Any} end + # Generate accessors $(GlobalRef(Gen, :get_ir))(::$gen_fn_type_name) = $(QuoteNode(ir)) $(GlobalRef(Gen, :get_ir))(::Type{$gen_fn_type_name}) = $(QuoteNode(ir)) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index de2c84b30..a38cfdbd7 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -9,26 +9,8 @@ end function get_schema end @inline get_address_schema(::Type{StaticIRTraceAssmt{T}}) where {T} = get_schema(T) - @inline Base.isempty(choices::StaticIRTraceAssmt) = isempty(choices.trace) - -@inline static_has_value(choices::StaticIRTraceAssmt, key) = false - -@inline function get_value(choices::StaticIRTraceAssmt, key::Symbol) - static_get_value(choices, Val(key)) -end - -@inline function has_value(choices::StaticIRTraceAssmt, key::Symbol) - static_has_value(choices, Val(key)) -end - -@inline function get_submap(choices::StaticIRTraceAssmt, key::Symbol) - static_get_submap(choices, Val(key)) -end -static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() - -@inline get_value(choices::StaticIRTraceAssmt, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::StaticIRTraceAssmt, addr::Pair) = _has_value(choices, addr) +@inline get_submap(choices::StaticIRTraceAssmt, key::Symbol) = static_get_submap(choices, Val(key)) @inline get_submap(choices::StaticIRTraceAssmt, addr::Pair) = _get_submap(choices, addr) ######################### @@ -37,21 +19,21 @@ static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() abstract type StaticIRTrace <: Trace end -@inline function static_get_subtrace(trace::StaticIRTrace, addr) - error("Not implemented") -end +@inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() +@inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false - Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) +@inline Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) -@inline function Base.getindex(trace::StaticIRTrace, addr) - Gen.static_getindex(trace, Val(addr)) -end +@inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_getindex(trace, Val(addr)) @inline function Base.getindex(trace::StaticIRTrace, addr::Pair) first, rest = addr return Gen.static_get_subtrace(trace, Val(first))[rest] end +@inline get_choices(trace::T) where {T <: StaticIRTrace} = StaticIRTraceAssmt{T}(trace) + const arg_prefix = gensym("arg") const choice_value_prefix = gensym("choice_value") const choice_score_prefix = gensym("choice_score") @@ -62,18 +44,10 @@ function get_value_fieldname(node::ArgumentNode) Symbol("$(arg_prefix)_$(node.name)") end -function get_value_fieldname(node::RandomChoiceNode) - Symbol("$(choice_value_prefix)_$(node.addr)") -end - function get_value_fieldname(node::JuliaNode) Symbol("$(julia_prefix)_$(node.name)") end -function get_score_fieldname(node::RandomChoiceNode) - Symbol("$(choice_score_prefix)_$(node.addr)") -end - function get_subtrace_fieldname(node::GenerativeFunctionCallNode) Symbol("$(subtrace_prefix)_$(node.addr)") end @@ -94,12 +68,6 @@ function get_trace_fields(ir::StaticIR, options::StaticIRGenerativeFunctionOptio fieldname = get_value_fieldname(node) push!(fields, TraceField(fieldname, node.typ)) end - for node in ir.choice_nodes - value_fieldname = get_value_fieldname(node) - push!(fields, TraceField(value_fieldname, node.typ)) - score_fieldname = get_score_fieldname(node) - push!(fields, TraceField(score_fieldname, QuoteNode(Float64))) - end for node in ir.call_nodes subtrace_fieldname = get_subtrace_fieldname(node) subtrace_type = QuoteNode(get_trace_type(node.generative_function)) @@ -154,27 +122,6 @@ function generate_get_retval(ir::StaticIR, trace_struct_name::Symbol) Expr(:block, :(trace.$return_value_fieldname))) end -function generate_get_choices(trace_struct_name::Symbol) - Expr(:function, - Expr(:call, GlobalRef(Gen, :get_choices), :(trace::$trace_struct_name)), - Expr(:if, :(!isempty(trace)), - :($(QuoteNode(StaticIRTraceAssmt))(trace)), - :($(QuoteNode(EmptyChoiceMap))()))) -end - -function generate_get_values_shallow(ir::StaticIR, trace_struct_name::Symbol) - elements = [] - for node in ir.choice_nodes - addr = node.addr - value = :(choices.trace.$(get_value_fieldname(node))) - push!(elements, :(($(QuoteNode(addr)), $value))) - end - Expr(:function, - Expr(:call, GlobalRef(Gen, :get_values_shallow), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name})), - Expr(:block, Expr(:tuple, elements...))) -end - function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] for node in ir.call_nodes @@ -210,43 +157,8 @@ function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) end ) end - - choice_getindex_exprs = Expr[] - for node in ir.choice_nodes - push!(choice_getindex_exprs, - quote - function $(GlobalRef(Gen, :static_getindex))(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) - return trace.$(get_value_fieldname(node)) - end - end - ) - end - - return [get_subtrace_exprs; call_getindex_exprs; choice_getindex_exprs] -end - -function generate_static_get_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, GlobalRef(Gen, :static_get_value), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(choices.trace.$(get_value_fieldname(node)))))) - end - methods -end - -function generate_static_has_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, GlobalRef(Gen, :static_has_value), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(true)))) - end - methods + + return [get_subtrace_exprs; call_getindex_exprs] end function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) @@ -260,21 +172,11 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) :($(GlobalRef(Gen, :get_choices))(choices.trace.$(get_subtrace_fieldname(node))))))) end - # throw a KeyError if get_submap is run on an address containing a value - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, GlobalRef(Gen, :static_get_submap), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(throw(KeyError($(QuoteNode(node.addr)))))))) - end methods end function generate_get_schema(ir::StaticIR, trace_struct_name::Symbol) - choice_addrs = [QuoteNode(node.addr) for node in ir.choice_nodes] - call_addrs = [QuoteNode(node.addr) for node in ir.call_nodes] - addrs = vcat(choice_addrs, call_addrs) + addrs = [QuoteNode(node.addr) for node in ir.call_nodes] Expr(:function, Expr(:call, GlobalRef(Gen, :get_schema), :(::Type{$trace_struct_name})), Expr(:block, @@ -289,20 +191,14 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_score_expr = generate_get_score(trace_struct_name) get_args_expr = generate_get_args(ir, trace_struct_name) get_retval_expr = generate_get_retval(ir, trace_struct_name) - get_choices_expr = generate_get_choices(trace_struct_name) get_schema_expr = generate_get_schema(ir, trace_struct_name) - get_values_shallow_expr = generate_get_values_shallow(ir, trace_struct_name) get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name) - static_get_value_exprs = generate_static_get_value(ir, trace_struct_name) - static_has_value_exprs = generate_static_has_value(ir, trace_struct_name) static_get_submap_exprs = generate_static_get_submap(ir, trace_struct_name) getindex_exprs = generate_getindex(ir, trace_struct_name) exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr, get_args_expr, get_retval_expr, - get_choices_expr, get_schema_expr, get_values_shallow_expr, - get_submaps_shallow_expr, static_get_value_exprs..., - static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...) + get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) (exprs, trace_struct_name) end diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index 5768791c1..64b5fb101 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -9,7 +9,6 @@ const calldiff_prefix = gensym("calldiff") calldiff_var(node::GenerativeFunctionCallNode) = Symbol("$(calldiff_prefix)_$(node.addr)") const choice_discard_prefix = gensym("choice_discard") -choice_discard_var(node::RandomChoiceNode) = Symbol("$(choice_discard_prefix)_$(node.addr)") const call_discard_prefix = gensym("call_discard") call_discard_var(node::GenerativeFunctionCallNode) = Symbol("$(call_discard_prefix)_$(node.addr)") @@ -19,21 +18,18 @@ call_discard_var(node::GenerativeFunctionCallNode) = Symbol("$(call_discard_pref ######################## struct ForwardPassState - input_changed::Set{Union{RandomChoiceNode,GenerativeFunctionCallNode}} + input_changed::Set{GenerativeFunctionCallNode} value_changed::Set{StaticIRNode} - constrained_or_selected_choices::Set{RandomChoiceNode} constrained_or_selected_calls::Set{GenerativeFunctionCallNode} discard_calls::Set{GenerativeFunctionCallNode} end function ForwardPassState() - input_changed = Set{Union{RandomChoiceNode,GenerativeFunctionCallNode}}() + input_changed = Set{GenerativeFunctionCallNode}() value_changed = Set{StaticIRNode}() - constrained_or_selected_choices = Set{RandomChoiceNode}() constrained_or_selected_calls = Set{GenerativeFunctionCallNode}() discard_calls = Set{GenerativeFunctionCallNode}() - ForwardPassState(input_changed, value_changed, constrained_or_selected_choices, - constrained_or_selected_calls, discard_calls) + ForwardPassState(input_changed, value_changed, constrained_or_selected_calls, discard_calls) end function forward_pass_argdiff!(state::ForwardPassState, @@ -46,40 +42,53 @@ function forward_pass_argdiff!(state::ForwardPassState, end end -function process_forward!(::AddressSchema, ::ForwardPassState, ::TrainableParameterNode) end +function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, ::ForwardPassState, ::TrainableParameterNode) end -function process_forward!(::AddressSchema, ::ForwardPassState, node::ArgumentNode) end +function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, ::ForwardPassState, node::ArgumentNode) end -function process_forward!(::AddressSchema, state::ForwardPassState, node::JuliaNode) +function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, state::ForwardPassState, node::JuliaNode) if any(input_node in state.value_changed for input_node in node.inputs) push!(state.value_changed, node) end end -function process_forward!(schema::AddressSchema, state::ForwardPassState, - node::RandomChoiceNode) - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) - if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.constrained_or_selected_choices, node) - push!(state.value_changed, node) - end - if any(input_node in state.value_changed for input_node in node.inputs) - push!(state.input_changed, node) - end +function cannot_statically_guarantee_nochange_retdiff(constraint_type, node, state) + update_fn = constraint_type <: ChoiceMap ? Gen.update : Gen.regenerate + + trace_type = get_trace_type(node.generative_function) + argdiff_types = map(input_node -> input_node in state.value_changed ? UnknownChange : NoChange, node.inputs) + argdiff_type = Tuple{argdiff_types...} + # TODO: can we know the arg type statically? + update_rettype = Core.Compiler.return_type( + update_fn, + Tuple{trace_type, Tuple, argdiff_type, constraint_type} + ) + has_static_retdiff = update_rettype <: Tuple && update_rettype != Union{} && length(update_rettype.parameters) >= 3 + guaranteed_returns_nochange = has_static_retdiff && update_rettype.parameters[3] == NoChange + + # println("$trace_type, Tuple, $argdiff_type, $constraint_type >> $update_rettype : $has_static_retdiff") + + return !guaranteed_returns_nochange end -function process_forward!(schema::AddressSchema, state::ForwardPassState, +function process_forward!(constraint_type::Type{<:Union{<:ChoiceMap, <:Selection}}, state::ForwardPassState, node::GenerativeFunctionCallNode) + schema = get_address_schema(constraint_type) + will_run_update = false @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) push!(state.constrained_or_selected_calls, node) - push!(state.value_changed, node) - push!(state.discard_calls, node) + will_run_update = true end if any(input_node in state.value_changed for input_node in node.inputs) push!(state.input_changed, node) - push!(state.value_changed, node) # TODO can check whether the node is satically absorbing + will_run_update = true + end + if will_run_update push!(state.discard_calls, node) + if cannot_statically_guarantee_nochange_retdiff(constraint_type, node, state) + push!(state.value_changed, node) + end end end @@ -113,15 +122,6 @@ function process_backward!(fwd::ForwardPassState, back::BackwardPassState, end end -function process_backward!(fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, options) - if node in fwd.input_changed || node in fwd.constrained_or_selected_choices - for input_node in node.inputs - push!(back.marked, input_node) - end - end -end - function process_backward!(fwd::ForwardPassState, back::BackwardPassState, node::GenerativeFunctionCallNode, options) if node in fwd.input_changed || node in fwd.constrained_or_selected_calls @@ -189,118 +189,6 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, end end -function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, ::UpdateMode, - options) - if options.track_diffs - - # track diffs - arg_values, _ = arg_values_and_diffs_from_tracked_diffs(node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :static_get_value))(constraints, Val($addr)), $(GlobalRef(Gen, :UnknownChange))()))) - push!(stmts, :($(choice_discard_var(node)) = trace.$(get_value_fieldname(node)))) - else - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), NoChange()))) - end - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $(GlobalRef(Gen, :strip_diff))($(node.name)), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), $(GlobalRef(Gen, :NoChange))()))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - - else - - # no track diffs - arg_values = map((n) -> n.name, node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :static_get_value))(constraints, Val($addr)))) - push!(stmts, :($(choice_discard_var(node)) = trace.$(get_value_fieldname(node)))) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - end - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - end -end - -function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, ::RegenerateMode, - options) - if options.track_diffs - - # track diffs - arg_values, _ = arg_values_and_diffs_from_tracked_diffs(node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - output_value = Expr(:call, (GlobalRef(Gen, :strip_diff)), node.name) - if node in fwd.constrained_or_selected_choices - # the choice was selected, it does not contribute to the weight - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :random))($dist, $(arg_values...)), UnknownChange()))) - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $output_value, $(arg_values...)))) - else - # the choice was not selected, and the input to the choice changed - # it does contribute to the weight - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), NoChange()))) - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $output_value, $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), NoChange()))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - else - - # no track diffs - arg_values = map((n) -> n.name, node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - # the choice was selected, it does not contribute to the weight - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(arg_values...)))) - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(arg_values...)))) - else - # the choice was not selected, and the input to the choice changed - # it does contribute to the weight - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - end -end - function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, node::GenerativeFunctionCallNode, ::UpdateMode, options) @@ -431,32 +319,20 @@ function generate_new_trace!(stmts::Vector{Expr}, trace_type::Type, options) end end -function generate_discard!(stmts::Vector{Expr}, - constrained_choices::Set{RandomChoiceNode}, - discard_calls::Set{GenerativeFunctionCallNode}) - discard_leaf_nodes = Dict{Symbol,Symbol}() - for node in constrained_choices - discard_leaf_nodes[node.addr] = choice_discard_var(node) - end - discard_internal_nodes = Dict{Symbol,Symbol}() +function generate_discard!(stmts::Vector{Expr}, discard_calls::Set{GenerativeFunctionCallNode}) + discard_nodes = Dict{Symbol,Symbol}() for node in discard_calls - discard_internal_nodes[node.addr] = call_discard_var(node) - end - if length(discard_leaf_nodes) > 0 - (leaf_keys, leaf_nodes) = collect(zip(discard_leaf_nodes...)) - else - (leaf_keys, leaf_nodes) = ((), ()) + discard_nodes[node.addr] = call_discard_var(node) end - if length(discard_internal_nodes) > 0 - (internal_keys, internal_nodes) = collect(zip(discard_internal_nodes...)) + + if length(discard_nodes) > 0 + (keys, nodes) = collect(zip(discard_nodes...)) else - (internal_keys, internal_nodes) = ((), ()) + (keys, nodes) = ((), ()) end - leaf_keys = map((key::Symbol) -> QuoteNode(key), leaf_keys) - internal_keys = map((key::Symbol) -> QuoteNode(key), internal_keys) - expr = :($(QuoteNode(StaticChoiceMap))( - $(QuoteNode(NamedTuple)){($(leaf_keys...),)}(($(leaf_nodes...),)), - $(QuoteNode(NamedTuple)){($(internal_keys...),)}(($(internal_nodes...),)))) + keys = map((key::Symbol) -> QuoteNode(key), keys) + expr = quote $(QuoteNode(StaticChoiceMap))( + $(QuoteNode(NamedTuple)){($(keys...),)}(($(nodes...),))) end push!(stmts, :($discard = $expr)) end @@ -481,7 +357,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ fwd_state = ForwardPassState() forward_pass_argdiff!(fwd_state, ir.arg_nodes, argdiffs_type) for node in ir.nodes - process_forward!(schema, fwd_state, node) + process_forward!(constraints_type, fwd_state, node) end # backward marking pass @@ -504,7 +380,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ end generate_return_value!(stmts, fwd_state, ir.return_node, options) generate_new_trace!(stmts, trace_type, options) - generate_discard!(stmts, fwd_state.constrained_or_selected_choices, fwd_state.discard_calls) + generate_discard!(stmts, fwd_state.discard_calls) # return trace and weight and discard and retdiff push!(stmts, :(return ($trace, $weight, $retdiff, $discard))) @@ -529,7 +405,7 @@ function codegen_regenerate(trace_type::Type{T}, args_type::Type, argdiffs_type: fwd_state = ForwardPassState() forward_pass_argdiff!(fwd_state, ir.arg_nodes, argdiffs_type) for node in ir.nodes - process_forward!(schema, fwd_state, node) + process_forward!(selection_type, fwd_state, node) end # backward marking pass diff --git a/test/assignment.jl b/test/assignment.jl index dcaa237fa..8ec08f992 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -1,6 +1,48 @@ +@testset "ValueChoiceMap" begin + vcm1 = ValueChoiceMap(2) + vcm2 = ValueChoiceMap(2.) + vcm3 = ValueChoiceMap([1,2]) + @test vcm1 isa ValueChoiceMap{Int} + @test vcm2 isa ValueChoiceMap{Float64} + @test vcm3 isa ValueChoiceMap{Vector{Int}} + @test vcm1[] == 2 + @test vcm1[] == get_value(vcm1) + + @test !isempty(vcm1) + @test has_value(vcm1) + @test get_value(vcm1) == 2 + @test vcm1 == vcm2 + @test isempty(get_submaps_shallow(vcm1)) + @test isempty(get_values_shallow(vcm1)) + @test isempty(get_nonvalue_submaps_shallow(vcm1)) + @test to_array(vcm1, Int) == [2] + @test from_array(vcm1, [4]) == ValueChoiceMap(4) + @test from_array(vcm3, [4, 5]) == ValueChoiceMap([4, 5]) + @test_throws Exception merge(vcm1, vcm2) + @test_throws Exception merge(vcm1, choicemap(:a, 5)) + @test merge(vcm1, EmptyChoiceMap()) == vcm1 + @test merge(EmptyChoiceMap(), vcm1) == vcm1 + @test get_submap(vcm1, :addr) == EmptyChoiceMap() + @test_throws ChoiceMapGetValueError get_value(vcm1, :addr) + @test !has_value(vcm1, :addr) + @test isapprox(vcm2, ValueChoiceMap(prevfloat(2.))) + @test isapprox(vcm1, ValueChoiceMap(prevfloat(2.))) + @test get_address_schema(typeof(vcm1)) == AllAddressSchema() + @test get_address_schema(ValueChoiceMap) == AllAddressSchema() + @test nested_view(vcm1) == 2 +end + +@testset "static choicemap constructor" begin + @test StaticChoiceMap((a=ValueChoiceMap(5), b=ValueChoiceMap(6))) == StaticChoiceMap(a=5, b=6) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + @test submap == StaticChoiceMap((a=ValueChoiceMap(1.), b=ValueChoiceMap([2., 2.5]))) + outer = StaticChoiceMap(c=3, d=submap, e=submap) + @test outer == StaticChoiceMap((c=ValueChoiceMap(3), d=submap, e=submap)) +end + @testset "static assignment to/from array" begin - submap = StaticChoiceMap((a=1., b=[2., 2.5]),NamedTuple()) - outer = StaticChoiceMap((c=3.,), (d=submap, e=submap)) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + outer = StaticChoiceMap(c=3., d=submap, e=submap) arr = to_array(outer, Float64) @test to_array(outer, Float64) == Float64[3.0, 1.0, 2.0, 2.5, 1.0, 2.0, 2.5] @@ -11,14 +53,16 @@ @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment to/from array" begin @@ -39,14 +83,18 @@ end @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test get_submap(choices, :c) == ValueChoiceMap(1.0) + @test get_submap(choices, :d => :b) == ValueChoiceMap([3.0, 4.0]) + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment copy constructor" begin @@ -64,25 +112,6 @@ end @test choices[:u => :w] == 4 end -@testset "internal vector assignment to/from array" begin - inner = choicemap() - set_value!(inner, :a, 1.) - set_value!(inner, :b, 2.) - outer = vectorize_internal([inner, inner, inner]) - - arr = to_array(outer, Float64) - @test to_array(outer, Float64) == Float64[1, 2, 1, 2, 1, 2] - - choices = from_array(outer, Float64[1, 2, 3, 4, 5, 6]) - @test choices[1 => :a] == 1.0 - @test choices[1 => :b] == 2.0 - @test choices[2 => :a] == 3.0 - @test choices[2 => :b] == 4.0 - @test choices[3 => :a] == 5.0 - @test choices[3 => :b] == 6.0 - @test length(collect(get_submaps_shallow(choices))) == 3 -end - @testset "dynamic assignment merge" begin submap = choicemap() set_value!(submap, :x, 1) @@ -107,7 +136,7 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @@ -125,8 +154,8 @@ end set_value!(submap, :x, 1) submap2 = choicemap() set_value!(submap2, :y, 4.) - choices1 = StaticChoiceMap((a=1., b=2.), (c=submap, shared=submap)) - choices2 = StaticChoiceMap((d=3.,), (e=submap, f=submap, shared=submap2)) + choices1 = StaticChoiceMap(a=1., b=2., c=submap, shared=submap) + choices2 = StaticChoiceMap(d=3., e=submap, f=submap, shared=submap2) choices = merge(choices1, choices2) @test choices[:a] == 1. @test choices[:b] == 2. @@ -136,124 +165,91 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @testset "static assignment variadic merge" begin - choices1 = StaticChoiceMap((a=1,), NamedTuple()) - choices2 = StaticChoiceMap((b=2,), NamedTuple()) - choices3 = StaticChoiceMap((c=3,), NamedTuple()) - choices_all = StaticChoiceMap((a=1, b=2, c=3), NamedTuple()) + choices1 = StaticChoiceMap(a=1) + choices2 = StaticChoiceMap(b=2) + choices3 = StaticChoiceMap(c=3) + choices_all = StaticChoiceMap(a=1, b=2, c=3) @test merge(choices1) == choices1 @test merge(choices1, choices2, choices3) == choices_all end +# TODO: in changing a lot of these to reflect the new behavior of choicemap, +# they are mostly not error checks, but instead checks for returning `EmptyChoiceMap`; +# should we relabel this testset? @testset "static assignment errors" begin + # get_choices on an address that returns a ValueChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x) == ValueChoiceMap(1) - # get_choices on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw - - # static_get_submap on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # static_get_submap on an address that contains a value returns a ValueChoiceMap + choices = StaticChoiceMap(x=1) + @test static_get_submap(choices, Val(:x)) == ValueChoiceMap(1) - # get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # get_submap on an address whose prefix contains a value returns EmptyChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) + choices = StaticChoiceMap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # static_get_choices on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # static_get_choices on an address that contains nothing returns an EmptyChoiceMap + choices = StaticChoiceMap() + @test static_get_submap(choices, Val(:x)) == EmptyChoiceMap() - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # static_get_value on an address that contains a submap throws a KeyError + # static_get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) - # get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw + # get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) - # static_get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw + # static_get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) end @testset "dynamic assignment errors" begin - - # get_choices on an address that contains a value throws a KeyError + # get_choices on an address that contains a value returns a ValueChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == ValueChoiceMap(1) - # get_choices on an address whose prefix contains a value throws a KeyError + # get_choices on an address whose prefix contains a value returns EmptyChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment choices = choicemap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError choices = choicemap() choices[:x => :y] = 1 - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # get_value on an address that contains nothing throws a KeyError + # get_value on an address that contains nothing throws a ChoiceMapGetValueError choices = choicemap() - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) end @testset "dynamic assignment overwrite" begin @@ -268,7 +264,8 @@ end choices = choicemap() choices[:x] = 1 @test has_value(choices, :x) - @test !has_submap(choices, :x) + @test has_submap(choices, :x) + @test !has_submap(choices, :z) submap = choicemap(); submap[:y] = 2 set_submap!(choices, :x, submap) @test !has_value(choices, :x) @@ -279,17 +276,15 @@ end choices = choicemap() choices[:x => :y] = 1 choices[:x] = 2 - threw = false - @test !has_submap(choices, :x) - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == ValueChoiceMap(2) @test choices[:x] == 2 # overwrite subassignment with a subassignment choices = choicemap() choices[:x => :y] = 1 @test has_submap(choices, :x) - @test !has_submap(choices, :x => :y) + @test has_submap(choices, :x => :y) # valuechoicemap + @test !has_submap(choices, :x => :z) submap = choicemap(); submap[:z] = 2 set_submap!(choices, :x, submap) @test !isempty(get_submap(choices, :x)) @@ -299,17 +294,13 @@ end # illegal set value under existing value choices = choicemap() choices[:x] = 1 - threw = false - try set_value!(choices, :x => :y, 2) catch KeyError threw = true end - @test threw + @test_throws Exception set_value!(choices, :x => :y, 2) # illegal set submap under existing value choices = choicemap() choices[:x] = 1 submap = choicemap(); choices[:z] = 2 - threw = false - try set_submap!(choices, :x => :y, submap) catch KeyError threw = true end - @test threw + @test_throws Exception set_submap!(choices, :x => :y, submap) end @testset "dynamic assignment constructor" begin diff --git a/test/benchmarks/dataset.jl b/test/benchmarks/dataset.jl new file mode 100644 index 000000000..5cd6308a7 --- /dev/null +++ b/test/benchmarks/dataset.jl @@ -0,0 +1,20 @@ + +function make_data_set(n) + Random.seed!(1) + prob_outlier = 0.5 + true_inlier_noise = 0.5 + true_outlier_noise = 5.0 + true_slope = -1 + true_intercept = 2 + xs = collect(range(-5, stop=5, length=n)) + ys = Float64[] + for (i, x) in enumerate(xs) + if rand() < prob_outlier + y = true_slope * x + true_intercept + randn() * true_inlier_noise + else + y = true_slope * x + true_intercept + randn() * true_outlier_noise + end + push!(ys, y) + end + (xs, ys) +end \ No newline at end of file diff --git a/test/benchmarks/dynamic_mh.jl b/test/benchmarks/dynamic_mh.jl new file mode 100644 index 000000000..fdaaf54f9 --- /dev/null +++ b/test/benchmarks/dynamic_mh.jl @@ -0,0 +1,105 @@ +module DynamicMHBenchmark +using Gen +import Random + +include("dataset.jl") + +######### +# model # +######### + +# TODO put this into FunctionalCollections: +import FunctionalCollections +Base.IndexStyle(::Type{<:FunctionalCollections.PersistentVector}) = IndexLinear() + +@gen function datum(x::Float64, (grad)(inlier_std::Float64), (grad)(outlier_std), (grad)(slope), (grad)(intercept))::Float64 + is_outlier = @trace(bernoulli(0.5), :z) + std = is_outlier ? inlier_std : outlier_std + y = @trace(normal(x * slope + intercept, std), :y) + return y +end + +data = Map(datum) + +@gen function model(xs::Vector{Float64}) + n = length(xs) + inlier_std = exp(@trace(normal(0, 2), :log_inlier_std)) + outlier_std = exp(@trace(normal(0, 2), :log_outlier_std)) + slope = @trace(normal(0, 2), :slope) + intercept = @trace(normal(0, 2), :intercept) + ys = @trace(data(xs, fill(inlier_std, n), fill(outlier_std, n), + fill(slope, n), fill(intercept, n)), :data) + return ys +end + + +@gen function slope_proposal(trace) + slope = trace[:slope] + @trace(normal(slope, 0.5), :slope) +end + +@gen function intercept_proposal(trace) + intercept = trace[:intercept] + @trace(normal(intercept, 0.5), :intercept) +end + +@gen function inlier_std_proposal(trace) + log_inlier_std = trace[:log_inlier_std] + @trace(normal(log_inlier_std, 0.5), :log_inlier_std) +end + +@gen function outlier_std_proposal(trace) + log_outlier_std = trace[:log_outlier_std] + @trace(normal(log_outlier_std, 0.5), :log_outlier_std) +end + +@gen function is_outlier_proposal(trace, i::Int) + prev = trace[:data => i => :z] + @trace(bernoulli(prev ? 0.0 : 1.0), :data => i => :z) +end + +function do_inference(xs, ys, num_iters) + observations = choicemap() + for (i, y) in enumerate(ys) + observations[:data => i => :y] = y + end + + # initial trace + (trace, _) = generate(model, (xs,), observations) + + scores = Vector{Float64}(undef, num_iters) + for i=1:num_iters + + # steps on the parameters + for j=1:5 + (trace, _) = metropolis_hastings(trace, slope_proposal, ()) + (trace, _) = metropolis_hastings(trace, intercept_proposal, ()) + (trace, _) = metropolis_hastings(trace, inlier_std_proposal, ()) + (trace, _) = metropolis_hastings(trace, outlier_std_proposal, ()) + end + + # step on the outliers + for j=1:length(xs) + (trace, _) = metropolis_hastings(trace, is_outlier_proposal, (j,)) + end + + score = get_score(trace) + scores[i] = score + + # print + slope = trace[:slope] + intercept = trace[:intercept] + inlier_std = exp(trace[:log_inlier_std]) + outlier_std = exp(trace[:log_outlier_std]) + end + return scores +end + +println("Simple dynamic DSL MH on regression model:") +(xs, ys) = make_data_set(200) +do_inference(xs, ys, 10) +@time do_inference(xs, ys, 20) +@time do_inference(xs, ys, 20) +println() + +end diff --git a/test/benchmarks/run_benchmarks.jl b/test/benchmarks/run_benchmarks.jl new file mode 100644 index 000000000..13a4754e8 --- /dev/null +++ b/test/benchmarks/run_benchmarks.jl @@ -0,0 +1,2 @@ +include("static_mh.jl") +include("dynamic_mh.jl") \ No newline at end of file diff --git a/test/benchmarks/static_mh.jl b/test/benchmarks/static_mh.jl new file mode 100644 index 000000000..6ae1545d2 --- /dev/null +++ b/test/benchmarks/static_mh.jl @@ -0,0 +1,107 @@ +module StaticMHBenchmark +using Gen +import Random +using Gen: ifelse + +include("dataset.jl") + +@gen (static) function slope_proposal(trace) + slope = trace[:slope] + @trace(normal(slope, 0.5), :slope) +end + +@gen (static) function intercept_proposal(trace) + intercept = trace[:intercept] + @trace(normal(intercept, 0.5), :intercept) +end + +@gen (static) function inlier_std_proposal(trace) + log_inlier_std = trace[:log_inlier_std] + @trace(normal(log_inlier_std, 0.5), :log_inlier_std) +end + +@gen (static) function outlier_std_proposal(trace) + log_outlier_std = trace[:log_outlier_std] + @trace(normal(log_outlier_std, 0.5), :log_outlier_std) +end + +@gen (static) function flip_z(z::Bool) + @trace(bernoulli(z ? 0.0 : 1.0), :z) +end + +@gen (static) function is_outlier_proposal(trace, i::Int) + prev_z = trace[:data => i => :z] + @trace(bernoulli(prev_z ? 0.0 : 1.0), :data => i => :z) +end + +@gen (static) function is_outlier_proposal(trace, i::Int) + prev_z = trace[:data => i => :z] + @trace(bernoulli(prev_z ? 0.0 : 1.0), :data => i => :z) +end + +@gen (static) function datum(x::Float64, (grad)(inlier_std::Float64), (grad)(outlier_std::Float64), + (grad)(slope::Float64), (grad)(intercept::Float64)) + is_outlier = @trace(bernoulli(0.5), :z) + std = ifelse(is_outlier, inlier_std, outlier_std) + y = @trace(normal(x * slope + intercept, std), :y) + return y +end + +data = Map(datum) + +@gen (static) function model(xs::Vector{Float64}) + n = length(xs) + inlier_log_std = @trace(normal(0, 2), :log_inlier_std) + outlier_log_std = @trace(normal(0, 2), :log_outlier_std) + inlier_std = exp(inlier_log_std) + outlier_std = exp(outlier_log_std) + slope = @trace(normal(0, 2), :slope) + intercept = @trace(normal(0, 2), :intercept) + @trace(data(xs, fill(inlier_std, n), fill(outlier_std, n), + fill(slope, n), fill(intercept, n)), :data) +end + +function do_inference(xs, ys, num_iters) + observations = choicemap() + for (i, y) in enumerate(ys) + observations[:data => i => :y] = y + end + + # initial trace + (trace, _) = generate(model, (xs,), observations) + + scores = Vector{Float64}(undef, num_iters) + for i=1:num_iters + + # steps on the parameters + for j=1:5 + (trace, _) = metropolis_hastings(trace, slope_proposal, ()) + (trace, _) = metropolis_hastings(trace, intercept_proposal, ()) + (trace, _) = metropolis_hastings(trace, inlier_std_proposal, ()) + (trace, _) = metropolis_hastings(trace, outlier_std_proposal, ()) + end + + # step on the outliers + for j=1:length(xs) + (trace, _) = metropolis_hastings(trace, is_outlier_proposal, (j,)) + end + + score = get_score(trace) + scores[i] = score + + # print + slope = trace[:slope] + intercept = trace[:intercept] + inlier_std = exp(trace[:log_inlier_std]) + outlier_std = exp(trace[:log_outlier_std]) + end + return scores +end + +(xs, ys) = make_data_set(200) +do_inference(xs, ys, 10) +println("Simple static DSL (including CallAt nodes) MH on regression model:") +@time do_inference(xs, ys, 50) +@time do_inference(xs, ys, 50) +println() +end diff --git a/test/dsl/dynamic_dsl.jl b/test/dsl/dynamic_dsl.jl index 25e3a50f4..a7e80f68d 100644 --- a/test/dsl/dynamic_dsl.jl +++ b/test/dsl/dynamic_dsl.jl @@ -124,7 +124,7 @@ end @test get_value(discard, :x) == x @test get_value(discard, :u => :a) == a @test length(collect(get_values_shallow(discard))) == 2 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 # test new trace new_assignment = get_choices(new_trace) @@ -132,7 +132,7 @@ end @test get_value(new_assignment, :y) == y @test get_value(new_assignment, :v => :b) == b @test length(collect(get_values_shallow(new_assignment))) == 2 - @test length(collect(get_submaps_shallow(new_assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(new_assignment))) == 1 # test score and weight prev_score = ( @@ -247,7 +247,7 @@ end @test !isempty(get_submap(assignment, :v)) end @test length(collect(get_values_shallow(assignment))) == 2 - @test length(collect(get_submaps_shallow(assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(assignment))) == 1 # test weight if assignment[:branch] == prev_assignment[:branch] @@ -337,11 +337,11 @@ end @test get_value(choices, :out) == out @test get_value(choices, :bar => :z) == z @test !has_value(choices, :b) # was not selected - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(collect(get_values_shallow(choices))) == 2 # check gradient trie - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(gradients))) == 2 @test !has_value(gradients, :b) # was not selected @test isapprox(get_value(gradients, :bar => :z), @@ -436,14 +436,14 @@ end @test choices[:x => 2] == 2 @test choices[:x => 3 => :z] == 3 @test length(collect(get_values_shallow(choices))) == 1 # :y - @test length(collect(get_submaps_shallow(choices))) == 1 # :x + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # :x submap = get_submap(choices, :x) @test submap[1] == 1 @test submap[2] == 2 @test submap[3 => :z] == 3 @test length(collect(get_values_shallow(submap))) == 2 # 1, 2 - @test length(collect(get_submaps_shallow(submap))) == 1 # 3 + @test length(collect(get_nonvalue_submaps_shallow(submap))) == 1 # 3 bar_submap = get_submap(submap, 3) @test bar_submap[:z] == 3 diff --git a/test/dsl/static_dsl.jl b/test/dsl/static_dsl.jl index d4bd80280..0d67cf0c4 100644 --- a/test/dsl/static_dsl.jl +++ b/test/dsl/static_dsl.jl @@ -40,13 +40,13 @@ end ret = @trace(bernoulli(0.5), :x => i) end -# @trace(choice_at(bernoulli)(0.5, i), :x) +# @trace(call_at(bernoulli)(0.5, i), :x) @gen (static) function at_choice_example_2(i::Int) ret = @trace(bernoulli(0.5), :x => i => :y) end -# @trace(call_at(choice_at(bernoulli))(0.5, i, :y), :x) +# @trace(call_at(call_at(bernoulli))(0.5, i, :y), :x) @gen function foo(mu) @trace(normal(mu, 1), :y) @@ -135,14 +135,13 @@ params = ir.arg_nodes[2] @test params.compute_grad # choice nodes and call nodes -@test length(ir.choice_nodes) == 2 -@test length(ir.call_nodes) == 0 +@test length(ir.call_nodes) == 2 # is_outlier -is_outlier = ir.choice_nodes[1] +is_outlier = ir.call_nodes[1] @test is_outlier.addr == :z @test is_outlier.typ == QuoteNode(Bool) -@test is_outlier.dist == bernoulli +@test is_outlier.generative_function == bernoulli @test length(is_outlier.inputs) == 1 # std @@ -156,10 +155,10 @@ in2 = std.inputs[2] @test (in1 === is_outlier && in2 === params) || (in2 === is_outlier && in1 === params) # y -y = ir.choice_nodes[2] +y = ir.call_nodes[2] @test y.addr == :y @test y.typ == QuoteNode(Float64) -@test y.dist == normal +@test y.generative_function == normal @test length(y.inputs) == 2 @test y.inputs[2] === std @@ -192,40 +191,39 @@ xs = ir.arg_nodes[1] @test xs.typ == :(Vector{Float64}) @test !xs.compute_grad -# choice nodes and call nodes -@test length(ir.choice_nodes) == 4 -@test length(ir.call_nodes) == 1 +# call nodes +@test length(ir.call_nodes) == 5 # inlier_std -inlier_std = ir.choice_nodes[1] +inlier_std = ir.call_nodes[1] @test inlier_std.addr == :inlier_std @test inlier_std.typ == QuoteNode(Float64) -@test inlier_std.dist == gamma +@test inlier_std.generative_function == gamma @test length(inlier_std.inputs) == 2 # outlier_std -outlier_std = ir.choice_nodes[2] +outlier_std = ir.call_nodes[2] @test outlier_std.addr == :outlier_std @test outlier_std.typ == QuoteNode(Float64) -@test outlier_std.dist == gamma +@test outlier_std.generative_function == gamma @test length(outlier_std.inputs) == 2 # slope -slope = ir.choice_nodes[3] +slope = ir.call_nodes[3] @test slope.addr == :slope @test slope.typ == QuoteNode(Float64) -@test slope.dist == normal +@test slope.generative_function == normal @test length(slope.inputs) == 2 # intercept -intercept = ir.choice_nodes[4] +intercept = ir.call_nodes[4] @test intercept.addr == :intercept @test intercept.typ == QuoteNode(Float64) -@test intercept.dist == normal +@test intercept.generative_function == normal @test length(intercept.inputs) == 2 # data -ys = ir.call_nodes[1] +ys = ir.call_nodes[5] @test ys.addr == :data @test ys.typ == QuoteNode(PersistentVector{Float64}) @test ys.generative_function == data_fn @@ -275,8 +273,8 @@ ret = get_node_by_addr(ir, :x) @test isa(ret.inputs[1], Gen.JuliaNode) # () -> 0.5 @test ret.inputs[2] === i at = ret.generative_function -@test isa(at, Gen.ChoiceAtCombinator) -@test at.dist == bernoulli +@test isa(at, Gen.CallAtCombinator) +@test at.kernel == bernoulli # at_choice_example_2 ir = Gen.get_ir(typeof(at_choice_example_2)) @@ -291,8 +289,8 @@ ret = get_node_by_addr(ir, :x) at = ret.generative_function @test isa(at, Gen.CallAtCombinator) at2 = at.kernel -@test isa(at2, Gen.ChoiceAtCombinator) -@test at2.dist == bernoulli +@test isa(at2, Gen.CallAtCombinator) +@test at2.kernel == bernoulli end @@ -394,7 +392,7 @@ ir2 = Gen.get_ir(typeof(f2)) return_node1 = ir1.return_node return_node2 = ir2.return_node @test isa(return_node2, typeof(return_node1)) -@test return_node2.dist == return_node1.dist +@test return_node2.generative_function == return_node1.generative_function inputs1 = return_node1.inputs inputs2 = return_node2.inputs @@ -603,12 +601,10 @@ tr = simulate(bar1, ()) ch = get_choices(tr) @test has_value(ch, :x) @test !has_value(ch, :y) -@test_throws KeyError get_submap(ch, :x) @test has_value(get_submap(ch, :a), :b) @test get_submap(ch, :y) == EmptyChoiceMap() -@test length(get_values_shallow(ch)) == 1 -@test length(get_submaps_shallow(ch)) == 1 - +@test length(collect(get_values_shallow(ch))) == 1 +@test length(collect(get_submaps_shallow(ch))) == 2 end @testset "returning a SML function from macro" begin diff --git a/test/dynamic_choicemap_benchmark.jl b/test/dynamic_choicemap_benchmark.jl new file mode 100644 index 000000000..3724e44de --- /dev/null +++ b/test/dynamic_choicemap_benchmark.jl @@ -0,0 +1,27 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +cm = choicemap((:a, 1), (:b => :c, 2)) + +println("dynamic choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(cm) +end + +println("dynamic choicemap nested lookup:") +for _=1:4 + @time many_nested(cm) +end \ No newline at end of file diff --git a/test/gen_fn_interface.jl b/test/gen_fn_interface.jl index be41ae51f..c7bb0e9a1 100644 --- a/test/gen_fn_interface.jl +++ b/test/gen_fn_interface.jl @@ -19,7 +19,7 @@ for (lang, f) in [:dynamic => f_dynamic, # sanity-check that `update` did what it's supposed to. @test get_args(trace1) == (5, 6) @test trace1[:z] == 0 - @test :z in keys(get_values_shallow(discard)) + @test :z in map(first, get_values_shallow(discard)) end @testset "regenerate(...) shorthand assuming unchanged args ($lang modeling lang)" begin diff --git a/test/inference/trace_translators.jl b/test/inference/trace_translators.jl index 6008ff5cf..fc2c365de 100644 --- a/test/inference/trace_translators.jl +++ b/test/inference/trace_translators.jl @@ -97,7 +97,7 @@ end else @write(p2_trace[:z], true, :discrete) x = @read(p1_trace[:x], :continuous) - i = ceil(x * 10) + i = Int(ceil(x * 10)) @write(p2_trace[:i], i, :discrete) @write(q2_trace[:dx], x - (i-1)/10, :continuous) end @@ -135,8 +135,8 @@ end @transform f (p1_trace, q1_trace) to (p2_trace, q2_trace) begin x = @read(p1_trace[:x], :continuous) y = @read(p1_trace[:y], :continuous) - i = ceil(x * 10) - j = ceil(y * 10) + i = Int(ceil(x * 10)) + j = Int(ceil(y * 10)) @write(p2_trace[:i], i, :discrete) @write(p2_trace[:j], j, :discrete) @write(q2_trace[:dx], x - (i-1)/10, :continuous) diff --git a/test/modeling_library/call_at.jl b/test/modeling_library/call_at.jl index b27f0130d..1985c610a 100644 --- a/test/modeling_library/call_at.jl +++ b/test/modeling_library/call_at.jl @@ -1,4 +1,4 @@ -@testset "call_at combinator" begin +@testset "call_at combinator on non-distribution" begin @gen (grad) function foo((grad)(x::Float64)) return x + @trace(normal(x, 1), :y) @@ -20,7 +20,7 @@ y = choices[3 => :y] @test isapprox(weight, logpdf(normal, y, 0.4, 1)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end @testset "generate" begin @@ -32,7 +32,7 @@ y = choices[3 => :y] @test get_retval(trace) == 0.4 + y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # with constraints y = 1.234 @@ -44,7 +44,7 @@ @test get_retval(trace) == 0.4 + y @test isapprox(weight, logpdf(normal, y, 0.4, 1.)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end function get_trace() @@ -71,7 +71,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isempty(discard) @@ -86,12 +86,12 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) # change kernel_args, different key, with constraint @@ -103,12 +103,12 @@ choices = get_choices(new_trace) @test choices[4 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) end @@ -121,7 +121,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isapprox(get_score(new_trace), logpdf(normal, y, 0.2, 1)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) y_new = choices[3 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -144,7 +144,7 @@ choices = get_choices(new_trace) y_new = choices[4 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -171,9 +171,9 @@ @test choices[3 => :y] == y @test isapprox(gradients[3 => :y], logpdf_grad(normal, y, 0.4, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 0 - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(input_grads) == 2 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.4, 1.0)[2] + retval_grad) @test input_grads[2] == nothing # the key has no gradient diff --git a/test/modeling_library/choice_at.jl b/test/modeling_library/choice_at.jl index 080b1b461..69eb52498 100644 --- a/test/modeling_library/choice_at.jl +++ b/test/modeling_library/choice_at.jl @@ -1,6 +1,6 @@ -@testset "choice_at combinator" begin +@testset "call_at combinator on distribution" begin - at = choice_at(bernoulli, Int) + at = call_at(bernoulli, Int) @testset "assess" begin choices = choicemap() @@ -15,7 +15,7 @@ @test isapprox(weight, value ? log(0.4) : log(0.6)) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end @testset "generate" begin @@ -27,7 +27,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 # with constraints constraints = choicemap() @@ -39,7 +39,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end function get_trace() @@ -65,7 +65,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isempty(discard) @@ -78,12 +78,12 @@ choices = get_choices(new_trace) @test choices[3] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 # change kernel_args, different key, with constraint constraints = choicemap() @@ -93,12 +93,12 @@ choices = get_choices(new_trace) @test choices[4] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 end @testset "regenerate" begin @@ -110,7 +110,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isapprox(get_score(new_trace), log(0.2)) @@ -122,7 +122,7 @@ choices = get_choices(new_trace) value = choices[3] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) value = choices[4] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -143,7 +143,7 @@ y = 1.2 constraints = choicemap() set_value!(constraints, 3, y) - (trace, _) = generate(choice_at(normal, Int), (0.0, 1.0, 3), constraints) + (trace, _) = generate(call_at(normal, Int), (0.0, 1.0, 3), constraints) # not selected (input_grads, choices, gradients) = choice_gradients( @@ -163,9 +163,9 @@ @test choices[3] == y @test isapprox(gradients[3], logpdf_grad(normal, y, 0.0, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 1 - @test length(collect(get_submaps_shallow(gradients))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 0 @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test length(input_grads) == 3 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.0, 1.0)[2]) @test isapprox(input_grads[2], logpdf_grad(normal, y, 0.0, 1.0)[3]) diff --git a/test/modeling_library/map.jl b/test/modeling_library/map.jl index 3c1f820fe..38a4e4aca 100644 --- a/test/modeling_library/map.jl +++ b/test/modeling_library/map.jl @@ -402,4 +402,19 @@ @test isapprox(get_param_grad(foo, :std), expected_std_grad) end + @testset "map over distribution" begin + flip_coins = Map(bernoulli) + coinflips_tr, weight = generate(flip_coins, (fill(0.4, 100),)) + @test weight == 0. + @test coinflips_tr[20] isa Bool + choices = get_choices(coinflips_tr) + @test get_submap(choices, 42) isa ValueChoiceMap{Bool} + val42 = get_value(choices, 42) + new_tr, weight, retdiff, discard = update(coinflips_tr, (fill(0.4, 100),), (NoChange(),), choicemap((42, !val42))) + @test new_tr[42] == !val42 + expected_score_change = logpdf(bernoulli, !val42, 0.4) - logpdf(bernoulli, val42, 0.4) + @test isapprox(get_score(new_tr) - get_score(coinflips_tr), expected_score_change) + @test isapprox(weight, expected_score_change) + end + end diff --git a/test/modeling_library/recurse.jl b/test/modeling_library/recurse.jl index 7fe7a9592..97725f968 100644 --- a/test/modeling_library/recurse.jl +++ b/test/modeling_library/recurse.jl @@ -197,9 +197,9 @@ end @test choices[(4, Val(:production)) => :rule] == 4 @test choices[(4, Val(:aggregation)) => :prefix] == false @test discard[(3, Val(:aggregation)) => :prefix] == true - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 @test length(collect(get_values_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 1 @test retdiff == UnknownChange() diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index 194998eb1..ee12e07c2 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -26,7 +26,7 @@ foo = Unfold(kernel) x3 = trace[3 => :x] choices = get_choices(trace) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 expected_score = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x2, x1 * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -53,7 +53,7 @@ foo = Unfold(kernel) @test choices[1 => :x] == x1 @test choices[3 => :x] == x3 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x2 = choices[2 => :x] expected_weight = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -75,7 +75,7 @@ foo = Unfold(kernel) beta = 0.3 (choices, weight, retval) = propose(foo, (3, x_init, alpha, beta)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x1 = choices[1 => :x] x2 = choices[2 => :x] x3 = choices[3 => :x] diff --git a/test/runtests.jl b/test/runtests.jl index 98338c227..86c297db7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -112,4 +112,4 @@ include("optional_args.jl") include("static_ir/static_ir.jl") include("tilde_sugar.jl") include("inference/inference.jl") -include("modeling_library/modeling_library.jl") +include("modeling_library/modeling_library.jl") \ No newline at end of file diff --git a/test/static_choicemap_benchmark.jl b/test/static_choicemap_benchmark.jl new file mode 100644 index 000000000..7b62fe8cf --- /dev/null +++ b/test/static_choicemap_benchmark.jl @@ -0,0 +1,48 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +scm = StaticChoiceMap(a=1, b=StaticChoiceMap(c=2)) + +println("static choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(scm) +end + +println("static choicemap nested lookup:") +for _=1:4 + @time many_nested(scm) +end + +@gen (static) function inner() + c ~ normal(0, 1) +end +@gen (static) function outer() + a ~ normal(0, 1) + b ~ inner() +end + +tr, _ = generate(outer, ()) +choices = get_choices(tr) + +println("static gen function choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(choices) +end + +println("static gen function choicemap nested lookup:") +for _=1:4 + @time many_nested(choices) +end diff --git a/test/static_inference_benchmark.jl b/test/static_inference_benchmark.jl new file mode 100644 index 000000000..953af236b --- /dev/null +++ b/test/static_inference_benchmark.jl @@ -0,0 +1,21 @@ +using Gen + +@gen (static, diffs) function foo() + a ~ normal(0, 1) + b ~ normal(a, 1) + c ~ normal(b, 1) +end + +observations = StaticChoiceMap(choicemap((:b,2), (:c,1.5))) +tr, _ = generate(foo, (), observations) + +function run_inference(trace) + tr = trace + for _=1:10^3 + tr, acc = mh(tr, select(:a)) + end +end + +for _=1:4 + @time run_inference(tr) +end \ No newline at end of file diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 44a308b83..cc406b1c7 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -361,12 +361,12 @@ end @test get_value(value_trie, :out) == out @test get_value(value_trie, :bar => :z) == z @test !has_value(value_trie, :b) # was not selected - @test length(get_submaps_shallow(value_trie)) == 1 - @test length(get_values_shallow(value_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(value_trie))) == 1 + @test length(collect(get_values_shallow(value_trie))) == 2 # check gradient trie - @test length(get_submaps_shallow(gradient_trie)) == 1 - @test length(get_values_shallow(gradient_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(gradient_trie))) == 1 + @test length(collect(get_values_shallow(gradient_trie))) == 2 @test !has_value(gradient_trie, :b) # was not selected @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx))