From 23f79a8592f16a0275c2c6a959d17f7edae91090 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sun, 17 May 2020 12:24:50 -0400 Subject: [PATCH 01/45] first draft of core functionality --- src/choice_map2/array_interface.jl | 92 ++++++++++ src/choice_map2/choice_map.jl | 247 ++++++++++++++++++++++++++ src/choice_map2/dynamic_choice_map.jl | 153 ++++++++++++++++ src/choice_map2/nested_view.jl | 81 +++++++++ src/choice_map2/static_choice_map.jl | 131 ++++++++++++++ 5 files changed, 704 insertions(+) create mode 100644 src/choice_map2/array_interface.jl create mode 100644 src/choice_map2/choice_map.jl create mode 100644 src/choice_map2/dynamic_choice_map.jl create mode 100644 src/choice_map2/nested_view.jl create mode 100644 src/choice_map2/static_choice_map.jl diff --git a/src/choice_map2/array_interface.jl b/src/choice_map2/array_interface.jl new file mode 100644 index 000000000..f88c5b116 --- /dev/null +++ b/src/choice_map2/array_interface.jl @@ -0,0 +1,92 @@ +### interface for to_array and fill_array ### + +""" + arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} + +Populate an array with values of choices in the given assignment. + +It is an error if each of the values cannot be coerced into a value of the +given type. + +Implementation + +The default implmentation of `fill_array` will populate the array by sorting +the addresses of the choicemap using the `sort` function, then iterating over +each submap in this order and filling the array for that submap. + +To override the default implementation of `to_array`, +a concrete subtype `T <: ChoiceMap` should implement the following method: + + n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Populate `arr` with values from the given assignment, starting at `start_idx`, +and return the number of elements in `arr` that were populated. + +(This is for performance; it is more efficient to fill in values in a preallocated array +by implementing `_fill_array!` than to construct discontiguous arrays for each submap and then merge them.) +""" +function to_array(choices::ChoiceMap, ::Type{T}) where {T} + arr = Vector{T}(undef, 32) + n = _fill_array!(choices, arr, 1) + @assert n <= length(arr) + resize!(arr, n) + arr +end + +function _fill_array!(c::ValueChoiceMap{<:T}, arr::Vector{T}, start_idx::Int) where {T} + if length(arr) <: start_idx + resize!(arr, 2 * start_idx) + end + arr[start_idx] = get_value(c) + 1 +end + +# default _fill_array! implementation +function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + key_to_submap = collect(get_submaps_shallow(choices)) + sort!(key_to_submap, by = ((key, submap),) -> key) + idx = start_idx + for (key, submap) in key_to_submap + n_written = _fill_array!(submap, arr, idx) + idx += n_written + end + idx - start_idx +end + +""" + choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) + +Return an assignment with the same address structure as a prototype +assignment, but with values read off from the given array. + +It is an error if the number of choices in the prototype assignment +is not equal to the length the array. + +The order in which addresses are populated with values from the array +should match the order in which the array is populated with values +in a call to `to_array(proto_choices, T)`. By default, +this means sorting the top-level addresses for `proto_choices` +and then filling in the submaps depth-first in this order. 
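+
+For example, as a sketch (the addresses below are illustrative, and the exact array
+ordering follows the default sorting described above):
+
+    proto = choicemap((:mu, 0.0), (:x => 1, 1.1), (:x => 2, 2.2))
+    arr = to_array(proto, Float64)             # e.g. [0.0, 1.1, 2.2]
+    shifted = from_array(proto, arr .+ 1.0)    # same addresses, values shifted by 1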
+ +# Implementation + +To support `from_array`, a concrete subtype `T <: ChoiceMap` must implement +the following method: + + (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Return an assignment with the same address structure as a prototype assignment, +but with values read off from `arr`, starting at position `start_idx`. Return the +number of elements read from `arr`. +""" +function from_array(proto_choices::ChoiceMap, arr::Vector) + (n, choices) = _from_array(proto_choices, arr, 1) + if n != length(arr) + error("Dimension mismatch: $n, $(length(arr))") + end + choices +end + +function _from_array(::ValueChoiceMap, arr::Vector, start_idx::Int) + ValueChoiceMap(arr[start_idx]) +end \ No newline at end of file diff --git a/src/choice_map2/choice_map.jl b/src/choice_map2/choice_map.jl new file mode 100644 index 000000000..d7e7101fe --- /dev/null +++ b/src/choice_map2/choice_map.jl @@ -0,0 +1,247 @@ +######################### +# choice map interface # +######################### + +""" + get_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for each top-level address associated with `choices`. +(This includes `ValueChoiceMap`s.) +""" +function get_submaps_shallow end + +""" + get_submap(choices::ChoiceMap, addr) + +Return the submap at the given address, or `EmptyChoiceMap` +if there is no submap at the given address. +""" +function get_submap end + +# provide _get_submap so when users overwrite get_submap(choices::CustomChoiceMap, addr::Pair) +# they can just call _get_submap for convenience if they want +@inline function _get_submap(choices::ChoiceMap, addr::Pair) + (first, rest) = addr + submap = get_submap(choices, first) + get_submap(submap, rest) +end +@inline get_submap(choices::ChoiceMap, addr::Pair) = _get_submap(choices, addr) + +""" + has_value(choices::ChoiceMap) + +Returns true if `choices` is a `ValueChoiceMap`. + + has_value(choices::ChoiceMap, addr) + +Returns true if `choices` has a value stored at address `addr`. +""" +function has_value end +@inline has_value(::ChoiceMap) = false +@inline has_value(c::ChoiceMap, addr) = has_value(get_submap(c, addr)) + +""" + get_value(choices::ChoiceMap) + +Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; +throws a `KeyError` if `choices` is not a `ValueChoiceMap`. + + get_value(choices::ChoiceMap, addr) +Returns the value stored in the submap with address `addr` or throws +a `KeyError` if no value exists at this address. + +A syntactic sugar is `Base.getindex`: + + value = choices[addr] +""" +function get_value end +get_value(::ChoiceMap) = throw(KeyError(nothing)) +get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) + +# get_values_shallow and get_nonvalue_submaps_shallow are just filters on get_submaps_shallow +""" + get_values_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, value)` +for each value stored at a top-level address in `choices`. +""" +function get_values_shallow(choices::ChoiceMap) + ( + (addr, get_value(submap)) + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) + ) +end + +""" + get_nonvalue_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for every top-level submap stored in `choices` which is +not a `ValueChoiceMap`. +""" +function get_nonvalue_submaps_shallow(choices::ChoiceMap) + filter(! 
∘ has_value, get_submaps_shallow(choices)) +end + +# a choicemap is empty if it has no submaps and no value +Base.isempty(c::ChoiceMap) = isempty(get_submaps_shallow(c)) && !has_value(c) + +""" + abstract type ChoiceMap end + +Abstract type for maps from hierarchical addresses to values. +""" +abstract type ChoiceMap end + +""" + EmptyChoiceMap + +A choicemap with no submaps or values. +""" +struct EmptyChoiceMap <: ChoiceMap end + +@inline has_value(::EmptyChoiceMap, addr...) = false +@inline get_value(::EmptyChoiceMap) = throw(KeyError(nothing)) +@inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() +@inline Base.isempty(::EmptyChoiceMap) = true +@inline get_submaps_shallow(::EmptyChoiceMap) = () + +""" + ValueChoiceMap + +A leaf-node choicemap. Stores a single value. +""" +struct ValueChoiceMap{T} <: ChoiceMap + val::T +end + +@inline has_value(choices::ValueChoiceMap) = true +@inline get_value(choices::ValueChoiceMap) = choices.val +@inline get_submap(choices::ValueChoiceMap, addr) = EmptyChoiceMap() +@inline get_submaps_shallow(choices::ValueChoiceMap) = () +Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val +Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) + +""" + choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) + +Merge two choice maps. + +It is an error if the choice maps both have values at the same address, or if +one choice map has a value at an address that is the prefix of the address of a +value in the other choice map. +""" +function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) + choices = DynamicChoiceMap() + for (key, submap) in get_submaps_shallow(choices1) + set_submap!(choices, key, merge(submap, get_submap(choices2, key))) + end + choices +end +Base.merge(c::ChoiceMap, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::ChoiceMap) = c +Base.merge(c::ValueChoiceMap, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::ValueChoiceMap) = c +Base.merge(::ValueChoiceMap, ::ChoiceMap) = error("ValueChoiceMaps cannot be merged") +Base.merge(::ChoiceMap, ::ValueChoiceMap) = error("ValueChoiceMaps cannot be merged") + +""" +Variadic merge of choice maps. +""" +function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) + reduce(Base.merge, choices_rest; init=choices1) +end + +function Base.:(==)(a::ChoiceMap, b::ChoiceMap) + for (addr, submap) in get_submaps_shallow(a) + if get_submap(b, addr) != submap + return false + end + end + return true +end + +function Base.isapprox(a::ChoiceMap, b::ChoiceMap) + for (addr, submap) in get_submaps_shallow(a) + if !isapprox(get_submap(b, addr), submap) + return false + end + end + return true +end + +""" + selected_choices = get_selected(choices::ChoiceMap, selection::Selection) + +Filter the choice map to include only choices in the given selection. + +Returns a new choice map. +""" +function get_selected( + choices::ChoiceMap, selection::Selection) + # TODO: return a `FilteringChoiceMap` which does this filtering lazily! 
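+    # Sketch of the intended behavior (illustrative addresses only):
+    #     get_selected(choicemap((:x, 1), (:y, 2)), select(:x))
+    # returns a choicemap containing only the choice at :x.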
+ output = choicemap() + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) && addr in selection + output[addr] = get_value(submap) + else + subselection = selection[addr] + set_submap!(output, addr, get_selected(submap, subselection)) + end + end + output +end + +function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + indent_vert_str = join(indent_vert) + indent_vert_last_str = join(indent_vert_last) + indent_str = join(indent) + indent_last_str = join(indent_last) + key_and_values = collect(get_values_shallow(choices)) + key_and_submaps = collect(get_nonvalue_submaps_shallow(choices)) + n = length(key_and_values) + length(key_and_submaps) + cur = 1 + for (key, value) in key_and_values + # For strings, `print` is what we want; `Base.show` includes quote marks. + # https://docs.julialang.org/en/v1/base/io-network/#Base.print + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end + for (key, submap) in key_and_submaps + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") + _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) + _show_pretty(io, choices, 0, ()) +end + +export ChoiceMap, ValueChoiceMap, EmptyChoiceMap +export get_submap, get_submaps_shallow +export get_value, has_value +export get_values_shallow, get_nonvalue_submaps_shallow + +include("array_interface.jl") +include("dynamic_choice_map.jl") +include("static_choice_map.jl") +include("nested_view.jl") \ No newline at end of file diff --git a/src/choice_map2/dynamic_choice_map.jl b/src/choice_map2/dynamic_choice_map.jl new file mode 100644 index 000000000..a93a49021 --- /dev/null +++ b/src/choice_map2/dynamic_choice_map.jl @@ -0,0 +1,153 @@ +####################### +# dynamic assignment # +####################### + +struct DynamicChoiceMap <: ChoiceMap + submaps::Dict{Any, <:ChoiceMap} +end + +""" + struct DynamicChoiceMap <: ChoiceMap .. end + +A mutable map from arbitrary hierarchical addresses to values. + + choices = DynamicChoiceMap() + +Construct an empty map. + + choices = DynamicChoiceMap(tuples...) + +Construct a map containing each of the given (addr, value) tuples. +""" +function DynamicChoiceMap() + DynamicChoiceMap(Dict()) +end + +function DynamicChoiceMap(tuples...) + choices = DynamicChoiceMap() + for (addr, value) in tuples + choices[addr] = value + end + choices +end + +""" + choices = DynamicChoiceMap(other::ChoiceMap) + +Copy a choice map, returning a mutable choice map. 
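+
+For example (illustrative; `other` can be any `ChoiceMap`, such as one returned by
+`get_choices` on a trace):
+
+    copied = DynamicChoiceMap(other)
+    copied[:extra] = 1.0    # `:extra` is a hypothetical address; the copy is mutable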
+""" +function DynamicChoiceMap(other::ChoiceMap) + choices = DynamicChoiceMap() + for (addr, submap) in get_submaps_shallow(other) + if choices isa ValueChoiceMap + set_submap!(choices, addr, submap) + else + set_submap!(choices, addr, DynamicChoiceMap(submap)) + end + end +end + +DynamicChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a DynamicChoiceMap") + +""" + choices = choicemap() + +Construct an empty mutable choice map. +""" +function choicemap() + DynamicChoiceMap() +end + +""" + choices = choicemap(tuples...) + +Construct a mutable choice map initialized with given address, value tuples. +""" +function choicemap(tuples...) + DynamicChoiceMap(tuples...) +end + +get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps +function get_submap(choices::DynamicChoiceMap, addr) + if haskey(choices.submaps, addr) + choices.submaps[addr] + else + EmptyChoiceMap() + end +end +get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) +Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) + +# mutation (not part of the assignment interface) + +""" + set_value!(choices::DynamicChoiceMap, addr, value) + +Set the given value for the given address. + +Will cause any previous value or sub-assignment at this address to be deleted. +It is an error if there is already a value present at some prefix of the given address. + +The following syntactic sugar is provided: + + choices[addr] = value +""" +function set_value!(choices::DynamicChoiceMap, addr, value) + delete!(choices.submaps, addr) + choices.submaps[addr] = ValueChoiceMap(value) +end + +function set_value!(choices::DynamicChoiceMap, addr::Pair, value) + (first, rest) = addr + if !haskey(choices.submaps, first) + choices.submaps[first] = DynamicChoiceMap() + elseif has_value(choices.submaps[first]) + error("Tried to create assignment at $first but there was already a value there.") + end + set_value!(choices.submaps[first], rest, value) +end + +""" + set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) + +Replace the sub-assignment rooted at the given address with the given sub-assignment. +Set the given value for the given address. + +Will cause any previous value or sub-assignment at the given address to be deleted. +It is an error if there is already a value present at some prefix of address. +""" +function set_submap!(choices::DynamicChoiceMap, addr, new_node) + delete!(choices.submaps, addr) + if !isempty(new_node) + choices.submaps[addr] = new_node + end +end + +function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) + (first, rest) = addr + if !haskey(choices.submaps, first) + choices.submaps[first] = DynamicChoiceMap() + elseif has_value(choices.submaps[first]) + error("Tried to create assignment at $first but there was already a value there.") + end + set_submap!(choices.submaps[first], rest, new_node) +end + +Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) + +function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + choices = DynamicChoiceMap() + keys_sorted = sort(collect(keys(choices.submaps))) + idx = start_idx + for key in keys_sorted + (n_read, submap) = _from_array(proto_choices.submaps[key], arr, idx) + idx += n_read + choices.submaps[key] = submap + end + (idx - start_idx, choices) +end + +export DynamicChoiceMap +export choicemap +export set_value! +export set_submap! 
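+
+# Usage sketch for the mutation interface (addresses are illustrative only):
+#
+#     choices = choicemap((:x, 1))
+#     choices[:y => :z] = 2.0                         # sugar for set_value!
+#     set_submap!(choices, :sub, choicemap((:a, 3)))
+#     choices[:y => :z]                               # == 2.0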
\ No newline at end of file diff --git a/src/choice_map2/nested_view.jl b/src/choice_map2/nested_view.jl new file mode 100644 index 000000000..6693234fb --- /dev/null +++ b/src/choice_map2/nested_view.jl @@ -0,0 +1,81 @@ +############################################ +# Nested-dict–like accessor for choicemaps # +############################################ + +""" +Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than +the default syntax which looks like a flat dict of full keypaths. + +```jldoctest +julia> using Gen +julia> c = choicemap((:a, 1), + (:b => :c, 2)); +julia> cv = nested_view(c); +julia> c[:a] == cv[:a] +true +julia> c[:b => :c] == cv[:b][:c] +true +julia> length(cv) +2 +julia> length(cv[:b]) +1 +julia> sort(collect(keys(cv))) +[:a, :b] +julia> sort(collect(keys(cv[:b]))) +[:c] +``` +""" +struct ChoiceMapNestedView + choice_map::ChoiceMap +end + +ChoiceMapNestedView(cm::ValueChoiceMap) = get_value(cm) +ChoiceMapNestedView(::EmptyChoiceMap) = error("Can't convert an emptychoicemap to nested view.") + +function Base.getindex(choices::ChoiceMapNestedView, addr) + ChoiceMapNestedView(get_submap(choices, addr)) +end + +function Base.iterate(c::ChoiceMapNestedView) + itr = ((k, ChoiceMapNestedView(s)) for (k, s) in get_submaps_shallow(c.choice_map)) + r = Base.iterate(itr) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +function Base.iterate(c::ChoiceMapNestedView, state) + (itr, st) = state + r = Base.iterate(itr, st) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +# TODO: Allow different implementations of this method depending on the +# concrete type of the `ChoiceMap`, so that an already-existing data structure +# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it +# exists. +Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) + +function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) + a.choice_map = b.choice_map +end +function Base.length(cv::ChoiceMapNestedView) + length(collect(get_submaps_shallow(cv.choice_map))) +end +function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) + Base.show(io, MIME"text/plain"(), c.choice_map) +end + +nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) + +# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling +# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and +# aux data together. + +export nested_view \ No newline at end of file diff --git a/src/choice_map2/static_choice_map.jl b/src/choice_map2/static_choice_map.jl new file mode 100644 index 000000000..e5e2d89e2 --- /dev/null +++ b/src/choice_map2/static_choice_map.jl @@ -0,0 +1,131 @@ +###################### +# static assignment # +###################### + +struct StaticChoiceMap{Addrs, SubmapTypes} <: ChoiceMap + submaps::NamedTuple{Addrs, SubmapTypes} +end + +@inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) +@inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) +@inline get_submap(choices::StaticChoiceMap, addr::Symbol) = static_get_submap(choices, Val(addr)) + +# TODO: profiling! 
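+# Note: `static_get_submap` is generated so that the submap lookup is resolved from the
+# `Addrs` type parameter at compile time. For example (illustrative), if `Addrs == (:x, :y)`,
+# then `static_get_submap(choices, Val(:x))` effectively returns `choices.submaps[:x]`,
+# while any address not in `Addrs` yields `EmptyChoiceMap()`.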
+@generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} + if A in Addrs + quote choices.submaps[A] end + else + quote EmptyChoiceMap() end + end +end + +static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) + +# convert a nonvalue choicemap all of whose top-level-addresses +# are symbols into a staticchoicemap at the top level +function StaticChoiceMap(other::ChoiceMap) + keys_and_nodes = get_submaps_shallow(other) + (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + StaticChoiceMap(NamedTuple{addrs}(submaps)) +end +StaticChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a StaticChoiceMap") + +# TODO: deep conversion to static choicemap + +""" + choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) + +Return an assignment that contains `choices1` as a sub-assignment under `key1` +and `choices2` as a sub-assignment under `key2`. +""" +function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) + StaticChoiceMap(NamedTuple{(key1, key2)}((choices1, choices2))) +end + +""" + (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) + +Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. + +It is an error if there are any submaps at keys other than `key1` and `key2`. +""" +function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) + if length(collect(get_submaps_shallow(choices))) != 2 + error("Not a pair") + end + (get_submap(choices, key1), get_submap(choices, key2)) +end + +@generated function Base.merge(choices1::StaticChoiceMap{Addrs1, SubmapTypes1}, + choices2::StaticChoiceMap{Addrs2, SubmapTypes2}) where {Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} + + addr_to_type1 = Dict{Symbol, ::Type{<:ChoiceMap}}() + addr_to_type2 = Dict{Symbol, ::Type{<:ChoiceMap}}() + for (i, addr) in enumerate(Addrs1) + addr_to_type1[addr] = SubmapTypes1.parameters[i] + end + for (i, addr) in enumerate(Addrs2) + addr_to_type2[addr] = SubmapTypes2.parameters[i] + end + + merged_addrs = Tuple(union(Set(Addrs1), Set(Addrs2))) + submap_exprs = [] + + for addr in merged_addrs + type1 = get(addr_to_type1, addr, EmptyChoiceMap) + type2 = get(addr_to_type2, addr, EmptyChoiceMap) + if ((type1 <: ValueChoiceMap && type2 != EmptyChoiceMap) + || (type2 <: ValueChoiceMap && type1 != EmptyChoiceMap)) + error( "One choicemap has a value at address $addr; the other is nonempty at $addr. Cannot merge.") + end + if type1 <: ValueChoiceMap + push!(submap_exprs, + quote choices1.submaps[$addr] end + ) + elseif type2 <: ValueChoiceMap + push!(submap_exprs, + quote choices2.submaps[$addr] end + ) + else + push!(submap_exprs, + quote merge(choices1.submaps[$addr], choices2.submaps[$addr]) end + ) + end + end + + quote + StaticChoiceMap{$merged_addrs}(submap_exprs...) 
+ end +end + +@generated function _from_array!(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, + arr::Vector{T}, start_idx::Int) where {T, Addrs, SubmapTypes} + + perm = sortperm(Addrs) + sorted_addrs = Addrs[perm] + submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) + + exprs = [quote idx = start_idx end] + + for (idx, addr) in zip(perm, sorted_addrs) + submap_var_name = gensym(addr) + submap_var_names[idx] = submap_var_name + push!(exprs, + quote + (n_read, submap_var_name = _from_array(proto_choices.submaps[$addr], arr, idx) + idx += n_read + end + ) + end + + quote + $(exprs...) + submaps = NamedTuple{Addrs}(( $(submap_var_names...) )) + choices = StaticChoiceMap{Addrs, SubmapTypes}(submaps) + (idx - start_idx, choices) + end +end + +export StaticChoiceMap +export pair, unpair +export static_get_submap, static_get_value \ No newline at end of file From c9b1d4982e5f8f4903254adc982d4d5a216c5580 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sun, 17 May 2020 12:33:00 -0400 Subject: [PATCH 02/45] add support for address schemas --- src/choice_map2/choice_map.jl | 9 +++++++++ src/choice_map2/dynamic_choice_map.jl | 2 ++ src/choice_map2/static_choice_map.jl | 4 ++++ 3 files changed, 15 insertions(+) diff --git a/src/choice_map2/choice_map.jl b/src/choice_map2/choice_map.jl index d7e7101fe..0ebb19f09 100644 --- a/src/choice_map2/choice_map.jl +++ b/src/choice_map2/choice_map.jl @@ -60,6 +60,13 @@ get_value(::ChoiceMap) = throw(KeyError(nothing)) get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) +""" +schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} + +Return the (top-level) address schema for the given choice map. +""" +function get_address_schema end + # get_values_shallow and get_nonvalue_submaps_shallow are just filters on get_submaps_shallow """ get_values_shallow(choices::ChoiceMap) @@ -108,6 +115,7 @@ struct EmptyChoiceMap <: ChoiceMap end @inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() @inline Base.isempty(::EmptyChoiceMap) = true @inline get_submaps_shallow(::EmptyChoiceMap) = () +@inline get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() """ ValueChoiceMap @@ -124,6 +132,7 @@ end @inline get_submaps_shallow(choices::ValueChoiceMap) = () Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) +@inline get_address_schema(::Type{<:ValueChoiceMap}) = EmptyAddressSchema() """ choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) diff --git a/src/choice_map2/dynamic_choice_map.jl b/src/choice_map2/dynamic_choice_map.jl index a93a49021..5dfca0b55 100644 --- a/src/choice_map2/dynamic_choice_map.jl +++ b/src/choice_map2/dynamic_choice_map.jl @@ -147,6 +147,8 @@ function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx: (idx - start_idx, choices) end +get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() + export DynamicChoiceMap export choicemap export set_value! 
diff --git a/src/choice_map2/static_choice_map.jl b/src/choice_map2/static_choice_map.jl index e5e2d89e2..3508762d7 100644 --- a/src/choice_map2/static_choice_map.jl +++ b/src/choice_map2/static_choice_map.jl @@ -126,6 +126,10 @@ end end end +function get_address_schema(::Type{StaticChoiceMap{Addrs, SubmapTypes}}) where {Addrs, SubmapTypes} + StaticAddressSchema(set(Addrs)) +end + export StaticChoiceMap export pair, unpair export static_get_submap, static_get_value \ No newline at end of file From 1e0a58997d4717eb687aba6306b8b556108475bb Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sun, 17 May 2020 12:42:52 -0400 Subject: [PATCH 03/45] update choicemap docs --- docs/src/ref/choice_maps.md | 43 +++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index c065b1b32..8d3f4200e 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -8,13 +8,20 @@ ChoiceMap Choice maps are constructed by users to express observations and/or constraints on the traces of generative functions. Choice maps are also returned by certain Gen inference methods, and are used internally by various Gen inference methods. +A choicemap a tree, whose leaf nodes store a single value, and whose internal nodes provide addresses +for sub-choicemaps. Leaf nodes have type: +```@docs +ValueChoiceMap +``` + Choice maps provide the following methods: ```@docs +get_submap +get_submaps_shallow has_value get_value -get_submap get_values_shallow -get_submaps_shallow +get_nonvalue_submaps_shallow to_array from_array get_selected @@ -50,3 +57,35 @@ choicemap set_value! set_submap! ``` + +## Implementing custom choicemap types + +To implement a custom choicemap, one must implement +`get_submap` and `get_submaps_shallow`. +To avoid method ambiguity with the default +`get_submap(::ChoiceMap, ::Pair)`, one must implement both +```julia +get_submap(::CustomChoiceMap, addr) +``` +and +```julia +get_submap(::CustomChoiceMap, addr::Pair) +``` +To use the default implementation of `get_submap(_, ::Pair)`, +one may define +```julia +get_submap(c::CustomChoiceMap, addr::Pair) = _get_choicemap(c, addr) +``` + +Once `get_submap` and `get_submaps_shallow` are defined, default +implementations are provided for: +- `has_value` +- `get_value` +- `get_values_shallow` +- `get_nonvalue_submaps_shallow` +- `to_array` +- `get_selected` + +If one wishes to support `from_array`, they must implement +`_from_array`, as described in the documentation for +[`from_array`](@ref). 
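+
+As a sketch (the `DictChoiceMap` type below is hypothetical and not part of Gen), a
+minimal custom choicemap could be implemented as:
+
+```julia
+struct DictChoiceMap <: ChoiceMap
+    submaps::Dict{Symbol, ChoiceMap}
+end
+
+# the two required methods
+Gen.get_submaps_shallow(c::DictChoiceMap) = pairs(c.submaps)
+Gen.get_submap(c::DictChoiceMap, addr) = get(c.submaps, addr, EmptyChoiceMap())
+
+# opt in to the default handling of hierarchical (`Pair`) addresses
+Gen.get_submap(c::DictChoiceMap, addr::Pair) = Gen._get_submap(c, addr)
+```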
\ No newline at end of file From 623bc8fcba7fc81eecb039a13d861baf06102d57 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 16:57:46 -0400 Subject: [PATCH 04/45] refactoring and tests --- docs/src/ref/choice_maps.md | 2 +- src/Gen.jl | 2 +- src/choice_map.jl | 1009 ----------------- .../array_interface.jl | 20 +- src/{choice_map2 => choice_map}/choice_map.jl | 57 +- .../dynamic_choice_map.jl | 20 +- .../nested_view.jl | 7 +- .../static_choice_map.jl | 52 +- src/dynamic/dynamic.jl | 31 +- src/dynamic/generate.jl | 2 +- src/dynamic/trace.jl | 36 +- src/dynamic/update.jl | 26 +- src/inference/kernel_dsl.jl | 11 +- src/modeling_library/call_at/call_at.jl | 5 +- src/modeling_library/choice_at/choice_at.jl | 4 +- src/modeling_library/recurse/recurse.jl | 18 +- src/modeling_library/vector.jl | 4 - src/static_ir/backprop.jl | 21 +- src/static_ir/trace.jl | 75 +- src/static_ir/update.jl | 7 +- test/assignment.jl | 224 ++-- test/benchmark.md | 21 + test/dynamic_dsl.jl | 14 +- test/modeling_library/call_at.jl | 26 +- test/modeling_library/choice_at.jl | 26 +- test/modeling_library/recurse.jl | 4 +- test/modeling_library/unfold.jl | 6 +- test/optional_args.jl | 2 +- test/runtests.jl | 2 +- test/static_ir/static_ir.jl | 10 +- test/tilde_sugar.jl | 2 +- 31 files changed, 347 insertions(+), 1399 deletions(-) delete mode 100644 src/choice_map.jl rename src/{choice_map2 => choice_map}/array_interface.jl (83%) rename src/{choice_map2 => choice_map}/choice_map.jl (82%) rename src/{choice_map2 => choice_map}/dynamic_choice_map.jl (93%) rename src/{choice_map2 => choice_map}/nested_view.jl (93%) rename src/{choice_map2 => choice_map}/static_choice_map.jl (68%) create mode 100644 test/benchmark.md diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index 8d3f4200e..6c445df6f 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -30,7 +30,7 @@ Note that none of these methods mutate the choice map. Choice maps also implement: -- `Base.isempty`, which tests of there are no random choices in the choice map +- `Base.isempty`, which returns `false` if the choicemap contains no value or submaps, and `true` otherwise. - `Base.merge`, which takes two choice maps, and returns a new choice map containing all random choices in either choice map. It is an error if the choice maps both have values at the same address, or if one choice map has a value at an address that is the prefix of the address of a value in the other choice map. diff --git a/src/Gen.jl b/src/Gen.jl index 9f3da9e3a..fa2393596 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -37,7 +37,7 @@ include("backprop.jl") include("address.jl") # abstract and built-in concrete choice map data types -include("choice_map.jl") +include("choice_map/choice_map.jl") # a homogeneous trie data type (not for use as choice map) include("trie.jl") diff --git a/src/choice_map.jl b/src/choice_map.jl deleted file mode 100644 index b7891b40a..000000000 --- a/src/choice_map.jl +++ /dev/null @@ -1,1009 +0,0 @@ -######################### -# choice map interface # -######################### - -""" - schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} - -Return the (top-level) address schema for the given choice map. -""" -function get_address_schema end - -""" - submap = get_submap(choices::ChoiceMap, addr) - -Return the sub-assignment containing all choices whose address is prefixed by addr. - -It is an error if the assignment contains a value at the given address. 
If -there are no choices whose address is prefixed by addr then return an -`EmptyChoiceMap`. -""" -function get_submap end - -""" - value = get_value(choices::ChoiceMap, addr) - -Return the value at the given address in the assignment, or throw a KeyError if -no value exists. A syntactic sugar is `Base.getindex`: - - value = choices[addr] -""" -function get_value end - -""" - key_submap_iterable = get_submaps_shallow(choices::ChoiceMap) - -Return an iterable collection of tuples `(key, submap::ChoiceMap)` for each top-level key -that has a non-empty sub-assignment. -""" -function get_submaps_shallow end - -""" - has_value(choices::ChoiceMap, addr) - -Return true if there is a value at the given address. -""" -function has_value end - -""" - key_submap_iterable = get_values_shallow(choices::ChoiceMap) - -Return an iterable collection of tuples `(key, value)` for each -top-level key associated with a value. -""" -function get_values_shallow end - -""" - abstract type ChoiceMap end - -Abstract type for maps from hierarchical addresses to values. -""" -abstract type ChoiceMap end - -""" - Base.isempty(choices::ChoiceMap) - -Return true if there are no values in the assignment. -""" -function Base.isempty(::ChoiceMap) - true -end - -@inline get_submap(choices::ChoiceMap, addr) = EmptyChoiceMap() -@inline has_value(choices::ChoiceMap, addr) = false -@inline get_value(choices::ChoiceMap, addr) = throw(KeyError(addr)) -@inline Base.getindex(choices::ChoiceMap, addr) = get_value(choices, addr) - -@inline function _has_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - has_value(submap, rest) -end - -@inline function _get_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_value(submap, rest) -end - -@inline function _get_submap(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_submap(submap, rest) -end - -function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) - VERT = '\u2502' - PLUS = '\u251C' - HORZ = '\u2500' - LAST = '\u2514' - indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) - indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) - for i in vert_bars - indent_vert[i] = VERT - indent[i] = VERT - indent_last[i] = VERT - end - indent_vert_str = join(indent_vert) - indent_vert_last_str = join(indent_vert_last) - indent_str = join(indent) - indent_last_str = join(indent_last) - key_and_values = collect(get_values_shallow(choices)) - key_and_submaps = collect(get_submaps_shallow(choices)) - n = length(key_and_values) + length(key_and_submaps) - cur = 1 - for (key, value) in key_and_values - # For strings, `print` is what we want; `Base.show` includes quote marks. - # https://docs.julialang.org/en/v1/base/io-network/#Base.print - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") - cur += 1 - end - for (key, submap) in key_and_submaps - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") - _show_pretty(io, submap, pre + 4, cur == n ? 
(vert_bars...,) : (vert_bars..., pre+1)) - cur += 1 - end -end - -function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) - _show_pretty(io, choices, 0, ()) -end - -# assignments that have static address schemas should also support faster -# accessors, which make the address explicit in the type (Val(:foo) instaed of -# :foo) -function static_get_value end -function static_get_submap end - -function _fill_array! end -function _from_array end - -""" - arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} - -Populate an array with values of choices in the given assignment. - -It is an error if each of the values cannot be coerced into a value of the -given type. - -# Implementation - -To support `to_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Populate `arr` with values from the given assignment, starting at `start_idx`, -and return the number of elements in `arr` that were populated. -""" -function to_array(choices::ChoiceMap, ::Type{T}) where {T} - arr = Vector{T}(undef, 32) - n = _fill_array!(choices, arr, 1) - @assert n <= length(arr) - resize!(arr, n) - arr -end - -function _fill_array!(value::T, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx - resize!(arr, 2 * start_idx) - end - arr[start_idx] = value - 1 -end - -function _fill_array!(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx + length(value) - resize!(arr, 2 * (start_idx + length(value))) - end - arr[start_idx:start_idx+length(value)-1] = value - length(value) -end - - -""" - choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) - -Return an assignment with the same address structure as a prototype -assignment, but with values read off from the given array. - -The order in which addresses are populated is determined by the prototype -assignment. It is an error if the number of choices in the prototype assignment -is not equal to the length the array. - -# Implementation - -To support `from_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - - (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Return an assignment with the same address structure as a prototype assignment, -but with values read off from `arr`, starting at position `start_idx`, and the -number of elements read from `arr`. -""" -function from_array(proto_choices::ChoiceMap, arr::Vector) - (n, choices) = _from_array(proto_choices, arr, 1) - if n != length(arr) - error("Dimension mismatch: $n, $(length(arr))") - end - choices -end - -function _from_array(::T, arr::Vector{T}, start_idx::Int) where {T} - (1, arr[start_idx]) -end - -function _from_array(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - n_read = length(value) - (n_read, arr[start_idx:start_idx+n_read-1]) -end - - -""" - choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - -Merge two choice maps. - -It is an error if the choice maps both have values at the same address, or if -one choice map has a value at an address that is the prefix of the address of a -value in the other choice map. 
-""" -function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - choices = DynamicChoiceMap() - for (key, value) in get_values_shallow(choices1) - choices.leaf_nodes[key] = value - end - for (key, node1) in get_submaps_shallow(choices1) - node2 = get_submap(choices2, key) - node = merge(node1, node2) - choices.internal_nodes[key] = node - end - for (key, value) in get_values_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has leaf node at $key") - end - if haskey(choices.internal_nodes, key) - error("choices1 has internal node at $key and choices2 has leaf node at $key") - end - choices.leaf_nodes[key] = value - end - for (key, node) in get_submaps_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has internal node at $key") - end - if !haskey(choices.internal_nodes, key) - # otherwise it should already be included - choices.internal_nodes[key] = node - end - end - return choices -end - -""" -Variadic merge of choice maps. -""" -function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) - reduce(Base.merge, choices_rest; init=choices1) -end - -function Base.:(==)(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || (get_value(b, addr) != value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || (get_value(a, addr) != value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if submap != get_submap(b, addr) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if submap != get_submap(a, addr) - return false - end - end - return true -end - -function Base.isapprox(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || !isapprox(get_value(b, addr), value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || !isapprox(get_value(a, addr), value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if !isapprox(submap, get_submap(b, addr)) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if !isapprox(submap, get_submap(a, addr)) - return false - end - end - return true -end - - -export ChoiceMap -export get_address_schema -export get_submap -export get_value -export has_value -export get_submaps_shallow -export get_values_shallow -export static_get_value -export static_get_submap -export to_array, from_array - - -###################### -# static assignment # -###################### - -struct StaticChoiceMap{R,S,T,U} <: ChoiceMap - leaf_nodes::NamedTuple{R,S} - internal_nodes::NamedTuple{T,U} - isempty::Bool -end - -function StaticChoiceMap{R,S,T,U}(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - -function StaticChoiceMap(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - - -# invariant: all internal_nodes are nonempty - -function get_address_schema(::Type{StaticChoiceMap{R,S,T,U}}) where {R,S,T,U} - keys = Set{Symbol}() - for (key, _) in zip(R, S.parameters) - push!(keys, key) - end - for (key, _) in zip(T, U.parameters) - push!(keys, key) - end - 
StaticAddressSchema(keys) -end - -function Base.isempty(choices::StaticChoiceMap) - choices.isempty -end - -get_values_shallow(choices::StaticChoiceMap) = pairs(choices.leaf_nodes) -get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.internal_nodes) -has_value(choices::StaticChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::StaticChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) - -# NOTE: there is no static_has_value because this is known from the static -# address schema - -## has_value ## - -function has_value(choices::StaticChoiceMap, key::Symbol) - haskey(choices.leaf_nodes, key) -end - -## get_submap ## - -function get_submap(choices::StaticChoiceMap, key::Symbol) - if haskey(choices.internal_nodes, key) - choices.internal_nodes[key] - elseif haskey(choices.leaf_nodes, key) - throw(KeyError(key)) - else - EmptyChoiceMap() - end -end - -function static_get_submap(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.internal_nodes[A] -end - -## get_value ## - -function get_value(choices::StaticChoiceMap, key::Symbol) - choices.leaf_nodes[key] -end - -function static_get_value(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.leaf_nodes[A] -end - -# convert from any other schema that has only Val{:foo} addresses -function StaticChoiceMap(other::ChoiceMap) - leaf_keys_and_nodes = collect(get_values_shallow(other)) - internal_keys_and_nodes = collect(get_submaps_shallow(other)) - if length(leaf_keys_and_nodes) > 0 - (leaf_keys, leaf_nodes) = collect(zip(leaf_keys_and_nodes...)) - else - (leaf_keys, leaf_nodes) = ((), ()) - end - if length(internal_keys_and_nodes) > 0 - (internal_keys, internal_nodes) = collect(zip(internal_keys_and_nodes...)) - else - (internal_keys, internal_nodes) = ((), ()) - end - StaticChoiceMap( - NamedTuple{leaf_keys}(leaf_nodes), - NamedTuple{internal_keys}(internal_nodes), - isempty(other)) -end - -""" - choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - -Return an assignment that contains `choices1` as a sub-assignment under `key1` -and `choices2` as a sub-assignment under `key2`. -""" -function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - StaticChoiceMap(NamedTuple(), NamedTuple{(key1,key2)}((choices1, choices2)), - isempty(choices1) && isempty(choices2)) -end - -""" - (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - -Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. - -It is an error if there are any top-level values, or any non-empty top-level -sub-assignments at keys other than `key1` and `key2`. -""" -function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - if !isempty(get_values_shallow(choices)) || length(collect(get_submaps_shallow(choices))) > 2 - error("Not a pair") - end - a = get_submap(choices, key1) - b = get_submap(choices, key2) - (a, b) -end - -# TODO make a generated function? 
-function _fill_array!(choices::StaticChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for value in choices.leaf_nodes - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for node in choices.internal_nodes - n_written = _fill_array!(node, arr, idx) - idx += n_written - end - idx - start_idx -end - -@generated function _from_array( - proto_choices::StaticChoiceMap{R,S,T,U}, arr::Vector{V}, start_idx::Int) where {R,S,T,U,V} - leaf_node_keys = proto_choices.parameters[1] - leaf_node_types = proto_choices.parameters[2].parameters - internal_node_keys = proto_choices.parameters[3] - internal_node_types = proto_choices.parameters[4].parameters - - exprs = [quote idx = start_idx end] - leaf_node_names = [] - internal_node_names = [] - - # leaf nodes - for key in leaf_node_keys - value = gensym() - push!(leaf_node_names, value) - push!(exprs, quote - (n_read, $value) = _from_array(proto_choices.leaf_nodes.$key, arr, idx) - idx += n_read - end) - end - - # internal nodes - for key in internal_node_keys - node = gensym() - push!(internal_node_names, node) - push!(exprs, quote - (n_read, $node) = _from_array(proto_choices.internal_nodes.$key, arr, idx) - idx += n_read - end) - end - - quote - $(exprs...) - leaf_nodes_field = NamedTuple{R,S}(($(leaf_node_names...),)) - internal_nodes_field = NamedTuple{T,U}(($(internal_node_names...),)) - choices = StaticChoiceMap{R,S,T,U}(leaf_nodes_field, internal_nodes_field) - (idx - start_idx, choices) - end -end - -@generated function Base.merge(choices1::StaticChoiceMap{R,S,T,U}, - choices2::StaticChoiceMap{W,X,Y,Z}) where {R,S,T,U,W,X,Y,Z} - - # unpack first assignment type parameters - leaf_node_keys1 = choices1.parameters[1] - leaf_node_types1 = choices1.parameters[2].parameters - internal_node_keys1 = choices1.parameters[3] - internal_node_types1 = choices1.parameters[4].parameters - keys1 = (leaf_node_keys1..., internal_node_keys1...,) - - # unpack second assignment type parameters - leaf_node_keys2 = choices2.parameters[1] - leaf_node_types2 = choices2.parameters[2].parameters - internal_node_keys2 = choices2.parameters[3] - internal_node_types2 = choices2.parameters[4].parameters - keys2 = (leaf_node_keys2..., internal_node_keys2...,) - - # leaf vs leaf collision is an error - colliding_leaf_leaf_keys = intersect(leaf_node_keys1, leaf_node_keys2) - if !isempty(colliding_leaf_leaf_keys) - error("choices1 and choices2 both have leaf nodes at key(s): $colliding_leaf_leaf_keys") - end - - # leaf vs internal collision is an error - colliding_leaf_internal_keys = intersect(leaf_node_keys1, internal_node_keys2) - if !isempty(colliding_leaf_internal_keys) - error("choices1 has leaf node and choices2 has internal node at key(s): $colliding_leaf_internal_keys") - end - - # internal vs leaf collision is an error - colliding_internal_leaf_keys = intersect(internal_node_keys1, leaf_node_keys2) - if !isempty(colliding_internal_leaf_keys) - error("choices1 has internal node and choices2 has leaf node at key(s): $colliding_internal_leaf_keys") - end - - # internal vs internal collision is not an error, recursively call merge - colliding_internal_internal_keys = (intersect(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys1_exclusive = (setdiff(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys2_exclusive = (setdiff(internal_node_keys2, internal_node_keys1)...,) - - # leaf nodes named tuple - leaf_node_keys = (leaf_node_keys1..., leaf_node_keys2...,) - leaf_node_types = map(QuoteNode, 
(leaf_node_types1..., leaf_node_types2...,)) - leaf_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys1]..., - [Expr(:(.), :(choices2.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys2]...) - leaf_nodes = Expr(:call, - Expr(:curly, :NamedTuple, - QuoteNode(leaf_node_keys), - Expr(:curly, :Tuple, leaf_node_types...)), - leaf_node_values) - - # internal nodes named tuple - internal_node_keys = (internal_node_keys1_exclusive..., - internal_node_keys2_exclusive..., - colliding_internal_internal_keys...) - internal_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)) - for key in internal_node_keys1_exclusive]..., - [Expr(:(.), :(choices2.internal_nodes), QuoteNode(key)) - for key in internal_node_keys2_exclusive]..., - [Expr(:call, :merge, - Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)), - Expr(:(.), :(choices2.internal_nodes), QuoteNode(key))) - for key in colliding_internal_internal_keys]...) - internal_nodes = Expr(:call, - Expr(:curly, :NamedTuple, QuoteNode(internal_node_keys)), - internal_node_values) - - # construct assignment from named tuples - Expr(:call, :StaticChoiceMap, leaf_nodes, internal_nodes) -end - -export StaticChoiceMap -export pair, unpair - -####################### -# dynamic assignment # -####################### - -struct DynamicChoiceMap <: ChoiceMap - leaf_nodes::Dict{Any,Any} - internal_nodes::Dict{Any,Any} - function DynamicChoiceMap(leaf_nodes::Dict{Any,Any}, internal_nodes::Dict{Any,Any}) - new(leaf_nodes, internal_nodes) - end -end - -# invariant: all internal nodes are nonempty - -""" - struct DynamicChoiceMap <: ChoiceMap .. end - -A mutable map from arbitrary hierarchical addresses to values. - - choices = DynamicChoiceMap() - -Construct an empty map. - - choices = DynamicChoiceMap(tuples...) - -Construct a map containing each of the given (addr, value) tuples. -""" -function DynamicChoiceMap() - DynamicChoiceMap(Dict(), Dict()) -end - -function DynamicChoiceMap(tuples...) - choices = DynamicChoiceMap() - for (addr, value) in tuples - choices[addr] = value - end - choices -end - -""" - choices = DynamicChoiceMap(other::ChoiceMap) - -Copy a choice map, returning a mutable choice map. -""" -function DynamicChoiceMap(other::ChoiceMap) - choices = DynamicChoiceMap() - for (addr, val) in get_values_shallow(other) - choices[addr] = val - end - for (addr, submap) in get_submaps_shallow(other) - set_submap!(choices, addr, DynamicChoiceMap(submap)) - end - choices -end - -""" - choices = choicemap() - -Construct an empty mutable choice map. -""" -function choicemap() - DynamicChoiceMap() -end - -""" - choices = choicemap(tuples...) - -Construct a mutable choice map initialized with given address, value tuples. -""" -function choicemap(tuples...) - DynamicChoiceMap(tuples...) 
-end - -get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() - -get_values_shallow(choices::DynamicChoiceMap) = choices.leaf_nodes - -get_submaps_shallow(choices::DynamicChoiceMap) = choices.internal_nodes - -has_value(choices::DynamicChoiceMap, addr::Pair) = _has_value(choices, addr) - -get_value(choices::DynamicChoiceMap, addr::Pair) = _get_value(choices, addr) - -get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::DynamicChoiceMap, addr) - if haskey(choices.internal_nodes, addr) - choices.internal_nodes[addr] - elseif haskey(choices.leaf_nodes, addr) - throw(KeyError(addr)) - else - EmptyChoiceMap() - end -end - -has_value(choices::DynamicChoiceMap, addr) = haskey(choices.leaf_nodes, addr) - -get_value(choices::DynamicChoiceMap, addr) = choices.leaf_nodes[addr] - -function Base.isempty(choices::DynamicChoiceMap) - isempty(choices.leaf_nodes) && isempty(choices.internal_nodes) -end - -# mutation (not part of the assignment interface) - -""" - set_value!(choices::DynamicChoiceMap, addr, value) - -Set the given value for the given address. - -Will cause any previous value or sub-assignment at this address to be deleted. -It is an error if there is already a value present at some prefix of the given address. - -The following syntactic sugar is provided: - - choices[addr] = value -""" -function set_value!(choices::DynamicChoiceMap, addr, value) - delete!(choices.internal_nodes, addr) - choices.leaf_nodes[addr] = value -end - -function set_value!(choices::DynamicChoiceMap, addr::Pair, value) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - node = choices.internal_nodes[first] - set_value!(node, rest, value) -end - -""" - set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) - -Replace the sub-assignment rooted at the given address with the given sub-assignment. -Set the given value for the given address. - -Will cause any previous value or sub-assignment at the given address to be deleted. -It is an error if there is already a value present at some prefix of address. -""" -function set_submap!(choices::DynamicChoiceMap, addr, new_node) - delete!(choices.leaf_nodes, addr) - delete!(choices.internal_nodes, addr) - if !isempty(new_node) - choices.internal_nodes[addr] = new_node - end -end - -function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. 
- error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - set_submap!(node, rest, new_node) -end - -Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) - -function _fill_array!(choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - leaf_keys_sorted = sort(collect(keys(choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - value = choices.leaf_nodes[key] - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for key in internal_node_keys_sorted - n_written = _fill_array!(get_submap(choices, key), arr, idx) - idx += n_written - end - idx - start_idx -end - -function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - @assert length(arr) >= start_idx - choices = DynamicChoiceMap() - leaf_keys_sorted = sort(collect(keys(proto_choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(proto_choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - (n_read, value) = _from_array(proto_choices.leaf_nodes[key], arr, idx) - idx += n_read - choices.leaf_nodes[key] = value - end - for key in internal_node_keys_sorted - (n_read, node) = _from_array(get_submap(proto_choices, key), arr, idx) - idx += n_read - choices.internal_nodes[key] = node - end - (idx - start_idx, choices) -end - -export DynamicChoiceMap -export choicemap -export set_value! -export set_submap! - - -####################################### -## vector combinator for assignments # -####################################### - -# TODO implement LeafVectorChoiceMap, which stores a vector of leaf nodes - -struct InternalVectorChoiceMap{T} <: ChoiceMap - internal_nodes::Vector{T} - is_empty::Bool -end - -function vectorize_internal(nodes::Vector{T}) where {T} - is_empty = all(map(isempty, nodes)) - InternalVectorChoiceMap(nodes, is_empty) -end - -# note some internal nodes may be empty - -get_address_schema(::Type{InternalVectorChoiceMap}) = VectorAddressSchema() - -Base.isempty(choices::InternalVectorChoiceMap) = choices.is_empty -has_value(choices::InternalVectorChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::InternalVectorChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::InternalVectorChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::InternalVectorChoiceMap, addr::Int) - if addr > 0 && addr <= length(choices.internal_nodes) - choices.internal_nodes[addr] - else - EmptyChoiceMap() - end -end - -function get_submaps_shallow(choices::InternalVectorChoiceMap) - ((i, choices.internal_nodes[i]) - for i=1:length(choices.internal_nodes) - if !isempty(choices.internal_nodes[i])) -end - -get_values_shallow(::InternalVectorChoiceMap) = () - -function _fill_array!(choices::InternalVectorChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for key=1:length(choices.internal_nodes) - n = _fill_array!(choices.internal_nodes[key], arr, idx) - idx += n - end - idx - start_idx -end - -function _from_array(proto_choices::InternalVectorChoiceMap{U}, arr::Vector{T}, start_idx::Int) where {T,U} - @assert length(arr) >= start_idx - nodes = Vector{U}(undef, length(proto_choices.internal_nodes)) - idx = start_idx - for 
key=1:length(proto_choices.internal_nodes) - (n_read, nodes[key]) = _from_array(proto_choices.internal_nodes[key], arr, idx) - idx += n_read - end - choices = InternalVectorChoiceMap(nodes, proto_choices.is_empty) - (idx - start_idx, choices) -end - -export InternalVectorChoiceMap -export vectorize_internal - - -#################### -# empty assignment # -#################### - -struct EmptyChoiceMap <: ChoiceMap end - -Base.isempty(::EmptyChoiceMap) = true -get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() -get_submaps_shallow(::EmptyChoiceMap) = () -get_values_shallow(::EmptyChoiceMap) = () - -_fill_array!(::EmptyChoiceMap, arr::Vector, start_idx::Int) = 0 -_from_array(::EmptyChoiceMap, arr::Vector, start_idx::Int) = (0, EmptyChoiceMap()) - -export EmptyChoiceMap - -############################################ -# Nested-dict–like accessor for choicemaps # -############################################ - -""" -Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than -the default syntax which looks like a flat dict of full keypaths. - -```jldoctest -julia> using Gen -julia> c = choicemap((:a, 1), - (:b => :c, 2)); -julia> cv = nested_view(c); -julia> c[:a] == cv[:a] -true -julia> c[:b => :c] == cv[:b][:c] -true -julia> length(cv) -2 -julia> length(cv[:b]) -1 -julia> sort(collect(keys(cv))) -[:a, :b] -julia> sort(collect(keys(cv[:b]))) -[:c] -``` -""" -struct ChoiceMapNestedView - choice_map::ChoiceMap -end - -function Base.getindex(choices::ChoiceMapNestedView, addr) - if has_value(choices.choice_map, addr) - return get_value(choices.choice_map, addr) - end - submap = get_submap(choices.choice_map, addr) - if isempty(submap) - throw(KeyError(addr)) - end - ChoiceMapNestedView(submap) -end - -function Base.iterate(c::ChoiceMapNestedView) - inner_iterator = Base.Iterators.flatten(( - get_values_shallow(c.choice_map), - ((k, ChoiceMapNestedView(v)) - for (k, v) in get_submaps_shallow(c.choice_map)))) - r = Base.iterate(inner_iterator) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -function Base.iterate(c::ChoiceMapNestedView, state) - (inner_iterator, inner_state) = state - r = Base.iterate(inner_iterator, inner_state) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -# TODO: Allow different implementations of this method depending on the -# concrete type of the `ChoiceMap`, so that an already-existing data structure -# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it -# exists. -Base.keys(cv::Gen.ChoiceMapNestedView) = (k for (k, v) in cv) - -function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) - a.choice_map == b.choice_map -end - -# Length of a `ChoiceMapNestedView` is number of leaf values + number of -# submaps. Motivation: This matches what `length` would return for the -# equivalent nested dict. -function Base.length(cv::ChoiceMapNestedView) - +(get_values_shallow(cv.choice_map) |> collect |> length, - get_submaps_shallow(cv.choice_map) |> collect |> length) -end - -function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) - Base.show(io, MIME"text/plain"(), c.choice_map) -end - -nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) - -# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling -# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and -# aux data together. 
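For reference while reading the deleted implementation above: the nested view survives in the reworked `nested_view.jl` below. A minimal sketch of its behavior under the new `ValueChoiceMap` convention, where a value node's nested view is the value itself:

```julia
using Gen

c = choicemap()
c[:a] = 1
c[:b => :c] = 2

cv = nested_view(c)
cv[:a]       # 1
cv[:b][:c]   # 2
length(cv)   # 2, one value node and one non-value submap at the top level
```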
- -export nested_view - -""" - selected_choices = get_selected(choices::ChoiceMap, selection::Selection) - -Filter the choice map to include only choices in the given selection. - -Returns a new choice map. -""" -function get_selected( - choices::ChoiceMap, selection::Selection) - output = choicemap() - for (key, value) in get_values_shallow(choices) - if (key in selection) - output[key] = value - end - end - for (key, submap) in get_submaps_shallow(choices) - subselection = selection[key] - set_submap!(output, key, get_selected(submap, subselection)) - end - output -end - -export get_selected diff --git a/src/choice_map2/array_interface.jl b/src/choice_map/array_interface.jl similarity index 83% rename from src/choice_map2/array_interface.jl rename to src/choice_map/array_interface.jl index f88c5b116..cf9d0bd03 100644 --- a/src/choice_map2/array_interface.jl +++ b/src/choice_map/array_interface.jl @@ -34,12 +34,20 @@ function to_array(choices::ChoiceMap, ::Type{T}) where {T} end function _fill_array!(c::ValueChoiceMap{<:T}, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) <: start_idx + if length(arr) < start_idx resize!(arr, 2 * start_idx) end arr[start_idx] = get_value(c) 1 end +function _fill_array!(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + value = get_value(c) + if length(arr) < start_idx + length(value) + resize!(arr, 2 * (start_idx + length(value))) + end + arr[start_idx:start_idx+length(value)-1] = value + length(value) +end # default _fill_array! implementation function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} @@ -88,5 +96,11 @@ function from_array(proto_choices::ChoiceMap, arr::Vector) end function _from_array(::ValueChoiceMap, arr::Vector, start_idx::Int) - ValueChoiceMap(arr[start_idx]) -end \ No newline at end of file + (1, ValueChoiceMap(arr[start_idx])) +end +function _from_array(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + n_read = length(get_value(c)) + (n_read, ValueChoiceMap(arr[start_idx:start_idx+n_read-1])) +end + +export to_array, from_array \ No newline at end of file diff --git a/src/choice_map2/choice_map.jl b/src/choice_map/choice_map.jl similarity index 82% rename from src/choice_map2/choice_map.jl rename to src/choice_map/choice_map.jl index 0ebb19f09..402cefa37 100644 --- a/src/choice_map2/choice_map.jl +++ b/src/choice_map/choice_map.jl @@ -2,6 +2,22 @@ # choice map interface # ######################### +""" + ChoiceMapGetValueError + +The error returned when a user attempts to call `get_value` +on an choicemap for an address which does not contain a value in that choicemap. +""" +struct ChoiceMapGetValueError <: Exception end +showerror(io::IO, ex::ChoiceMapGetValueError) = (print(io, "ChoiceMapGetValueError: no value was found for the `get_value` call.")) + +""" + abstract type ChoiceMap end + +Abstract type for maps from hierarchical addresses to values. +""" +abstract type ChoiceMap end + """ get_submaps_shallow(choices::ChoiceMap) @@ -26,7 +42,6 @@ function get_submap end submap = get_submap(choices, first) get_submap(submap, rest) end -@inline get_submap(choices::ChoiceMap, addr::Pair) = _get_submap(choices, addr) """ has_value(choices::ChoiceMap) @@ -45,18 +60,18 @@ function has_value end get_value(choices::ChoiceMap) Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; -throws a `KeyError` if `choices` is not a `ValueChoiceMap`. +throws a `ChoiceMapGetValueError` if `choices` is not a `ValueChoiceMap`. 
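A quick round-trip sketch of the array interface as changed in the hunk above, including the new contiguous handling of vector-valued leaves; it assumes the dynamic `choicemap` constructor and the default sorted-address ordering described in the docstrings:

```julia
using Gen

c = choicemap()
c[:a] = 1.0
c[:b] = [2.0, 2.5]          # a single vector-valued choice

arr = to_array(c, Float64)  # Float64[1.0, 2.0, 2.5]: addresses sorted, vectors flattened in place

c2 = from_array(c, Float64[3.0, 4.0, 4.5])
c2[:a]                      # 3.0
c2[:b]                      # [4.0, 4.5]
```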
get_value(choices::ChoiceMap, addr) Returns the value stored in the submap with address `addr` or throws -a `KeyError` if no value exists at this address. +a `ChoiceMapGetValueError` if no value exists at this address. A syntactic sugar is `Base.getindex`: value = choices[addr] """ function get_value end -get_value(::ChoiceMap) = throw(KeyError(nothing)) +get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) @@ -73,6 +88,8 @@ function get_address_schema end Returns an iterable collection of tuples `(address, value)` for each value stored at a top-level address in `choices`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) """ function get_values_shallow(choices::ChoiceMap) ( @@ -88,20 +105,15 @@ end Returns an iterable collection of tuples `(address, submap)` for every top-level submap stored in `choices` which is not a `ValueChoiceMap`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) """ function get_nonvalue_submaps_shallow(choices::ChoiceMap) - filter(! ∘ has_value, get_submaps_shallow(choices)) + (addr_to_submap for addr_to_submap in get_submaps_shallow(choices) if !has_value(addr_to_submap[2])) end # a choicemap is empty if it has no submaps and no value -Base.isempty(c::ChoiceMap) = isempty(get_submaps_shallow(c)) && !has_value(c) - -""" - abstract type ChoiceMap end - -Abstract type for maps from hierarchical addresses to values. -""" -abstract type ChoiceMap end +Base.isempty(c::ChoiceMap) = all(((addr, submap),) -> isempty(submap), get_submaps_shallow(c)) && !has_value(c) """ EmptyChoiceMap @@ -111,11 +123,14 @@ A choicemap with no submaps or values. struct EmptyChoiceMap <: ChoiceMap end @inline has_value(::EmptyChoiceMap, addr...) 
= false -@inline get_value(::EmptyChoiceMap) = throw(KeyError(nothing)) +@inline get_value(::EmptyChoiceMap) = throw(ChoiceMapGetValueError()) @inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() @inline Base.isempty(::EmptyChoiceMap) = true @inline get_submaps_shallow(::EmptyChoiceMap) = () @inline get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() +@inline Base.:(==)(::EmptyChoiceMap, ::EmptyChoiceMap) = true +@inline Base.:(==)(::ChoiceMap, ::EmptyChoiceMap) = false +@inline Base.:(==)(::EmptyChoiceMap, ::ChoiceMap) = false """ ValueChoiceMap @@ -148,6 +163,11 @@ function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) for (key, submap) in get_submaps_shallow(choices1) set_submap!(choices, key, merge(submap, get_submap(choices2, key))) end + for (key, submap) in get_submaps_shallow(choices2) + if isempty(get_submap(choices1, key)) + set_submap!(choices, key, submap) + end + end choices end Base.merge(c::ChoiceMap, ::EmptyChoiceMap) = c @@ -170,6 +190,11 @@ function Base.:(==)(a::ChoiceMap, b::ChoiceMap) return false end end + for (addr, submap) in get_submaps_shallow(b) + if get_submap(a, addr) != submap + return false + end + end return true end @@ -246,9 +271,11 @@ function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) end export ChoiceMap, ValueChoiceMap, EmptyChoiceMap -export get_submap, get_submaps_shallow +export _get_submap, get_submap, get_submaps_shallow export get_value, has_value export get_values_shallow, get_nonvalue_submaps_shallow +export get_address_schema, get_selected +export ChoiceMapGetValueError include("array_interface.jl") include("dynamic_choice_map.jl") diff --git a/src/choice_map2/dynamic_choice_map.jl b/src/choice_map/dynamic_choice_map.jl similarity index 93% rename from src/choice_map2/dynamic_choice_map.jl rename to src/choice_map/dynamic_choice_map.jl index 5dfca0b55..a3403307c 100644 --- a/src/choice_map2/dynamic_choice_map.jl +++ b/src/choice_map/dynamic_choice_map.jl @@ -2,10 +2,6 @@ # dynamic assignment # ####################### -struct DynamicChoiceMap <: ChoiceMap - submaps::Dict{Any, <:ChoiceMap} -end - """ struct DynamicChoiceMap <: ChoiceMap .. end @@ -19,8 +15,11 @@ Construct an empty map. Construct a map containing each of the given (addr, value) tuples. """ -function DynamicChoiceMap() - DynamicChoiceMap(Dict()) +struct DynamicChoiceMap <: ChoiceMap + submaps::Dict{Any, ChoiceMap} + function DynamicChoiceMap() + new(Dict()) + end end function DynamicChoiceMap(tuples...) @@ -39,12 +38,13 @@ Copy a choice map, returning a mutable choice map. function DynamicChoiceMap(other::ChoiceMap) choices = DynamicChoiceMap() for (addr, submap) in get_submaps_shallow(other) - if choices isa ValueChoiceMap + if submap isa ValueChoiceMap set_submap!(choices, addr, submap) else set_submap!(choices, addr, DynamicChoiceMap(submap)) end end + choices end DynamicChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a DynamicChoiceMap") @@ -116,14 +116,14 @@ Set the given value for the given address. Will cause any previous value or sub-assignment at the given address to be deleted. It is an error if there is already a value present at some prefix of address. 
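A sketch of the default `merge` as extended above, which now also copies addresses that appear only in the second map; built with the dynamic `choicemap` constructor and mirroring the merge tests later in this patch:

```julia
using Gen

a = choicemap(); a[:x] = 1; a[:shared => :p] = 2
b = choicemap(); b[:y] = 3; b[:shared => :q] = 4

m = merge(a, b)
(m[:x], m[:y])                        # (1, 3)
(m[:shared => :p], m[:shared => :q])  # (2, 4), shared prefixes merge recursively

v = choicemap(); v[:x] = 1
s = choicemap(); s[:x => :y] = 2
merge(v, s)                           # error: a value and a nonempty submap collide at :x
```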
""" -function set_submap!(choices::DynamicChoiceMap, addr, new_node) +function set_submap!(choices::DynamicChoiceMap, addr, new_node::ChoiceMap) delete!(choices.submaps, addr) if !isempty(new_node) choices.submaps[addr] = new_node end end -function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) +function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node::ChoiceMap) (first, rest) = addr if !haskey(choices.submaps, first) choices.submaps[first] = DynamicChoiceMap() @@ -137,7 +137,7 @@ Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, add function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} choices = DynamicChoiceMap() - keys_sorted = sort(collect(keys(choices.submaps))) + keys_sorted = sort(collect(keys(proto_choices.submaps))) idx = start_idx for key in keys_sorted (n_read, submap) = _from_array(proto_choices.submaps[key], arr, idx) diff --git a/src/choice_map2/nested_view.jl b/src/choice_map/nested_view.jl similarity index 93% rename from src/choice_map2/nested_view.jl rename to src/choice_map/nested_view.jl index 6693234fb..68add0a05 100644 --- a/src/choice_map2/nested_view.jl +++ b/src/choice_map/nested_view.jl @@ -33,7 +33,7 @@ ChoiceMapNestedView(cm::ValueChoiceMap) = get_value(cm) ChoiceMapNestedView(::EmptyChoiceMap) = error("Can't convert an emptychoicemap to nested view.") function Base.getindex(choices::ChoiceMapNestedView, addr) - ChoiceMapNestedView(get_submap(choices, addr)) + ChoiceMapNestedView(get_submap(choices.choice_map, addr)) end function Base.iterate(c::ChoiceMapNestedView) @@ -62,9 +62,8 @@ end # exists. Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) -function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) - a.choice_map = b.choice_map -end +Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) = a.choice_map == b.choice_map + function Base.length(cv::ChoiceMapNestedView) length(collect(get_submaps_shallow(cv.choice_map))) end diff --git a/src/choice_map2/static_choice_map.jl b/src/choice_map/static_choice_map.jl similarity index 68% rename from src/choice_map2/static_choice_map.jl rename to src/choice_map/static_choice_map.jl index 3508762d7..1f75b3bca 100644 --- a/src/choice_map2/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -4,13 +4,21 @@ struct StaticChoiceMap{Addrs, SubmapTypes} <: ChoiceMap submaps::NamedTuple{Addrs, SubmapTypes} + function StaticChoiceMap(submaps::NamedTuple{Addrs, SubmapTypes}) where {Addrs, SubmapTypes <: NTuple{n, ChoiceMap} where n} + new{Addrs, SubmapTypes}(submaps) + end +end + +function StaticChoiceMap(;addrs_to_vals_and_maps...) + addrs = Tuple(addr for (addr, val_or_map) in addrs_to_vals_and_maps) + maps = Tuple(val_or_map isa ChoiceMap ? val_or_map : ValueChoiceMap(val_or_map) for (addr, val_or_map) in addrs_to_vals_and_maps) + StaticChoiceMap(NamedTuple{addrs}(maps)) end @inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) @inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) @inline get_submap(choices::StaticChoiceMap, addr::Symbol) = static_get_submap(choices, Val(addr)) -# TODO: profiling! 
@generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} if A in Addrs quote choices.submaps[A] end @@ -18,17 +26,25 @@ end quote EmptyChoiceMap() end end end +static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) +static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) # convert a nonvalue choicemap all of whose top-level-addresses # are symbols into a staticchoicemap at the top level function StaticChoiceMap(other::ChoiceMap) - keys_and_nodes = get_submaps_shallow(other) - (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + keys_and_nodes = collect(get_submaps_shallow(other)) + if length(keys_and_nodes) > 0 + (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + else + addrs = () + submaps = () + end StaticChoiceMap(NamedTuple{addrs}(submaps)) end StaticChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a StaticChoiceMap") +StaticChoiceMap(::NamedTuple{(),Tuple{}}) = EmptyChoiceMap() # TODO: deep conversion to static choicemap @@ -58,9 +74,9 @@ end @generated function Base.merge(choices1::StaticChoiceMap{Addrs1, SubmapTypes1}, choices2::StaticChoiceMap{Addrs2, SubmapTypes2}) where {Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} - - addr_to_type1 = Dict{Symbol, ::Type{<:ChoiceMap}}() - addr_to_type2 = Dict{Symbol, ::Type{<:ChoiceMap}}() + + addr_to_type1 = Dict{Symbol, Type{<:ChoiceMap}}() + addr_to_type2 = Dict{Symbol, Type{<:ChoiceMap}}() for (i, addr) in enumerate(Addrs1) addr_to_type1[addr] = SubmapTypes1.parameters[i] end @@ -78,30 +94,30 @@ end || (type2 <: ValueChoiceMap && type1 != EmptyChoiceMap)) error( "One choicemap has a value at address $addr; the other is nonempty at $addr. Cannot merge.") end - if type1 <: ValueChoiceMap + if type1 <: EmptyChoiceMap push!(submap_exprs, - quote choices1.submaps[$addr] end + quote choices2.submaps.$addr end ) - elseif type2 <: ValueChoiceMap + elseif type2 <: EmptyChoiceMap push!(submap_exprs, - quote choices2.submaps[$addr] end + quote choices1.submaps.$addr end ) else push!(submap_exprs, - quote merge(choices1.submaps[$addr], choices2.submaps[$addr]) end + quote merge(choices1.submaps.$addr, choices2.submaps.$addr) end ) end end quote - StaticChoiceMap{$merged_addrs}(submap_exprs...) + StaticChoiceMap(NamedTuple{$merged_addrs}(($(submap_exprs...),))) end end -@generated function _from_array!(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, +@generated function _from_array(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, arr::Vector{T}, start_idx::Int) where {T, Addrs, SubmapTypes} - perm = sortperm(Addrs) + perm = sortperm(collect(Addrs)) sorted_addrs = Addrs[perm] submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) @@ -112,7 +128,7 @@ end submap_var_names[idx] = submap_var_name push!(exprs, quote - (n_read, submap_var_name = _from_array(proto_choices.submaps[$addr], arr, idx) + (n_read, $submap_var_name) = _from_array(proto_choices.submaps.$addr, arr, idx) idx += n_read end ) @@ -120,14 +136,14 @@ end quote $(exprs...) - submaps = NamedTuple{Addrs}(( $(submap_var_names...) 
)) - choices = StaticChoiceMap{Addrs, SubmapTypes}(submaps) + submaps = NamedTuple{Addrs}(( $(submap_var_names...), )) + choices = StaticChoiceMap(submaps) (idx - start_idx, choices) end end function get_address_schema(::Type{StaticChoiceMap{Addrs, SubmapTypes}}) where {Addrs, SubmapTypes} - StaticAddressSchema(set(Addrs)) + StaticAddressSchema(Set(Addrs)) end export StaticChoiceMap diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index c6f09374c..73f22159a 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -124,42 +124,33 @@ function visit!(visitor::AddressVisitor, addr) push!(visitor.visited, addr) end +all_visited(::Selection, ::ValueChoiceMap) = false +all_visited(::AllSelection, ::ValueChoiceMap) = true function all_visited(visited::Selection, choices::ChoiceMap) - allvisited = true - for (key, _) in get_values_shallow(choices) - allvisited = allvisited && (key in visited) - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - allvisited = allvisited && all_visited(subvisited, submap) + if !all_visited(visited[key], submap) + return false end end - allvisited + return true end +get_unvisited(::Selection, v::ValueChoiceMap) = v +get_unvisited(::AllSelection, v::ValueChoiceMap) = EmptyChoiceMap() function get_unvisited(visited::Selection, choices::ChoiceMap) unvisited = choicemap() - for (key, _) in get_values_shallow(choices) - if !(key in visited) - set_value!(unvisited, key, get_value(choices, key)) - end - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - sub_unvisited = get_unvisited(subvisited, submap) - set_submap!(unvisited, key, sub_unvisited) - end + sub_unvisited = get_unvisited(visited[key], submap) + set_submap!(unvisited, key, sub_unvisited) end unvisited end get_visited(visitor) = visitor.visited -function check_no_submap(constraints::ChoiceMap, addr) +function check_is_empty(constraints::ChoiceMap, addr) if !isempty(get_submap(constraints, addr)) - error("Expected a value at address $addr but found a sub-assignment") + error("Expected a value or EmptyChoiceMap at address $addr but found a sub-assignment") end end diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index df6a5f465..970dac42d 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -20,7 +20,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T}, # check for constraints at this key constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) + !constrained && check_is_empty(state.constraints, key) # get return value if constrained diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 8c02eceb5..882297e43 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -119,9 +119,6 @@ struct DynamicDSLChoiceMap <: ChoiceMap end get_address_schema(::Type{DynamicDSLChoiceMap}) = DynamicAddressSchema() -Base.isempty(::DynamicDSLChoiceMap) = false # TODO not necessarily true -has_value(choices::DynamicDSLChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::DynamicDSLChoiceMap, addr::Pair) = _get_value(choices, addr) get_submap(choices::DynamicDSLChoiceMap, addr::Pair) = _get_submap(choices, addr) function get_submap(choices::DynamicDSLChoiceMap, addr) @@ -130,9 +127,10 @@ function get_submap(choices::DynamicDSLChoiceMap, addr) # leaf node, must be a call call = trie[addr] if call.is_choice - throw(KeyError(addr)) + ValueChoiceMap(call.subtrace_or_retval) + else + 
get_choices(call.subtrace_or_retval) end - get_choices(call.subtrace_or_retval) elseif has_internal_node(trie, addr) # internal node subtrie = get_internal_node(trie, addr) @@ -142,32 +140,12 @@ function get_submap(choices::DynamicDSLChoiceMap, addr) end end -function has_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - has_leaf_node(trie, addr) && trie[addr].is_choice -end - -function get_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - choice = trie[addr] - if !choice.is_choice - throw(KeyError(addr)) - end - choice.subtrace_or_retval -end - -function get_values_shallow(choices::DynamicDSLChoiceMap) - ((key, choice.subtrace_or_retval) - for (key, choice) in get_leaf_nodes(choices.trie) - if choice.is_choice) -end - function get_submaps_shallow(choices::DynamicDSLChoiceMap) - calls_iter = ((key, get_choices(call.subtrace_or_retval)) + calls_iter = ( + (key, call.is_choice ? ValueChoiceMap(call.subtrace_or_retval) : get_choices(call.subtrace_or_retval)) for (key, call) in get_leaf_nodes(choices.trie) - if !call.is_choice) - internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) - for (key, trie) in get_internal_nodes(choices.trie)) + ) + internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) Iterators.flatten((calls_iter, internal_nodes_iter)) end diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 24e023f24..7acc16302 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -35,7 +35,7 @@ function traceat(state::GFUpdateState, dist::Distribution{T}, # check for constraints at this key constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) + !constrained && check_is_empty(state.constraints, key) # record the previous value as discarded if it is replaced if constrained && has_previous @@ -149,32 +149,22 @@ end function add_unvisited_to_discard!(discard::DynamicChoiceMap, visited::DynamicSelection, prev_choices::ChoiceMap) - for (key, value) in get_values_shallow(prev_choices) + for (key, submap) in get_submaps_shallow(prev_choices) + # if key IS in visited, + # the recursive call to update already handled the discard + # for this entire submap; else we need to handle it if !(key in visited) - @assert !has_value(discard, key) @assert isempty(get_submap(discard, key)) - set_value!(discard, key, value) - end - end - for (key, submap) in get_submaps_shallow(prev_choices) - @assert !has_value(discard, key) - if key in visited - # the recursive call to update already handled the discard - # for this entire submap - continue - else subvisited = visited[key] if isempty(subvisited) # none of this submap was visited, so we discard the whole thing - @assert isempty(get_submap(discard, key)) set_submap!(discard, key, submap) else subdiscard = get_submap(discard, key) - add_unvisited_to_discard!( - isempty(subdiscard) ? choicemap() : subdiscard, - subvisited, submap) + subdiscard = isempty(subdiscard) ? 
choicemap() : subdiscard + add_unvisited_to_discard!(subdiscard, subvisited, submap) set_submap!(discard, key, subdiscard) - end + end end end end diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index a231f03a7..d662dbb75 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -1,12 +1,13 @@ import MacroTools function check_observations(choices::ChoiceMap, observations::ChoiceMap) - for (key, value) in get_values_shallow(observations) - !has_value(choices, key) && error("Check failed: observed choice at $key not found") - choices[key] != value && error("Check failed: value of observed choice at $key changed") - end for (key, submap) in get_submaps_shallow(observations) - check_observations(get_submap(choices, key), submap) + if has_value(submap) + !has_value(choices, key) && error("Check failed: observed choice at $key not found") + choices[key] != value && error("Check failed: value of observed choice at $key changed") + else + check_observations(get_submap(choices, key), submap) + end end end diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 234116976..f17d061f8 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -14,10 +14,7 @@ function get_submap(choices::CallAtChoiceMap{K,T}, addr::K) where {K,T} end get_submap(choices::CallAtChoiceMap, addr::Pair) = _get_submap(choices, addr) -get_value(choices::CallAtChoiceMap, addr::Pair) = _get_value(choices, addr) -has_value(choices::CallAtChoiceMap, addr::Pair) = _has_value(choices, addr) get_submaps_shallow(choices::CallAtChoiceMap) = ((choices.key, choices.submap),) -get_values_shallow(::CallAtChoiceMap) = () # TODO optimize CallAtTrace using type parameters @@ -69,7 +66,7 @@ unpack_call_at_args(args) = (args[end], args[1:end-1]) function assess(gen_fn::CallAtCombinator, args::Tuple, choices::ChoiceMap) (key, kernel_args) = unpack_call_at_args(args) - if length(get_submaps_shallow(choices)) > 1 || length(get_values_shallow(choices)) > 0 + if length(get_submaps_shallow(choices)) > 1 error("Not all constraints were consumed") end submap = get_submap(choices, key) diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl index 69bb4851a..f38758956 100644 --- a/src/modeling_library/choice_at/choice_at.jl +++ b/src/modeling_library/choice_at/choice_at.jl @@ -25,10 +25,12 @@ function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap} end get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr) has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr) +get_submap(choices::ChoiceAtChoiceMap, addr::Pair) = _get_submap(choices, addr) function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K} choices.key == addr ? choices.value : throw(KeyError(choices, addr)) end -get_submaps_shallow(choices::ChoiceAtChoiceMap) = () +get_submap(choices::ChoiceAtChoiceMap, addr) = addr == choices.key ? 
ValueChoiceMap(choices.value) : EmptyChoiceMap() +get_submaps_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, ValueChoiceMap(choices.value)),) get_values_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, choices.value),) struct ChoiceAtCombinator{T,K} <: GenerativeFunction{T, ChoiceAtTrace} diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 715800737..1f1017251 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -84,17 +84,7 @@ function get_submap(choices::RecurseTraceChoiceMap, end end -function get_submap(choices::RecurseTraceChoiceMap, addr::Pair) - _get_submap(choices, addr) -end - -function has_value(choices::RecurseTraceChoiceMap, addr::Pair) - _has_value(choices, addr) -end - -function get_value(choices::RecurseTraceChoiceMap, addr::Pair) - _get_value(choices, addr) -end +get_submap(choices::RecurseTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) get_values_shallow(choices::RecurseTraceChoiceMap) = () @@ -333,6 +323,9 @@ function recurse_unpack_constraints(constraints::ChoiceMap) production_constraints = Dict{Int, Any}() aggregation_constraints = Dict{Int, Any}() for (addr, node) in get_submaps_shallow(constraints) + if has_value(node) + error("Unknown address: $(addr)") + end idx::Int = addr[1] if addr[2] == Val(:production) production_constraints[idx] = node @@ -342,9 +335,6 @@ function recurse_unpack_constraints(constraints::ChoiceMap) error("Unknown address: $addr") end end - if length(get_values_shallow(constraints)) > 0 - error("Unknown address: $(first(get_values_shallow(constraints))[1])") - end return (production_constraints, aggregation_constraints) end diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index 9b0eb763a..3af416ef8 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -92,10 +92,6 @@ end end @inline get_submap(choices::VectorTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) -@inline get_value(choices::VectorTraceChoiceMap, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::VectorTraceChoiceMap, addr::Pair) = _has_value(choices, addr) -@inline get_values_shallow(::VectorTraceChoiceMap) = () - ############################################ # code shared by vector-shaped combinators # diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 7a0fe384e..b352d3ca2 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -330,21 +330,22 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, value_trie::Symbol, gradient_trie::Symbol) selected_choices_vec = collect(selected_choices) quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) - leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec) - leaf_gradients = map((node) -> gradient_var(node), selected_choices_vec) + leaf_value_choicemaps = map((node) -> :(ValueChoiceMap(trace.$(get_value_fieldname(node)))), selected_choices_vec) + leaf_gradient_choicemaps = map((node) -> :(ValueChoiceMap($(gradient_var(node)))), selected_choices_vec) selected_calls_vec = collect(selected_calls) quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec) - internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), + internal_value_choicemaps = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), selected_calls_vec) - internal_gradients = map((node) -> 
gradient_trie_var(node), selected_calls_vec) + internal_gradient_choicemaps = map((node) -> gradient_trie_var(node), selected_calls_vec) + + quoted_all_keys = Iterators.flatten((quoted_leaf_keys, quoted_internal_keys)) + all_value_choicemaps = Iterators.flatten((leaf_value_choicemaps, internal_value_choicemaps)) + all_gradient_choicemaps = Iterators.flatten((leaf_gradient_choicemaps, internal_gradient_choicemaps)) + quote - $value_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),))) - $gradient_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),))) + $value_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_value_choicemaps...),))) + $gradient_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_gradient_choicemaps...),))) end end diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 713c0863a..5ac3ced16 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -9,25 +9,8 @@ end function get_schema end @inline get_address_schema(::Type{StaticIRTraceAssmt{T}}) where {T} = get_schema(T) - @inline Base.isempty(choices::StaticIRTraceAssmt) = isempty(choices.trace) - -@inline static_has_value(choices::StaticIRTraceAssmt, key) = false - -@inline function get_value(choices::StaticIRTraceAssmt, key::Symbol) - static_get_value(choices, Val(key)) -end - -@inline function has_value(choices::StaticIRTraceAssmt, key::Symbol) - static_has_value(choices, Val(key)) -end - -@inline function get_submap(choices::StaticIRTraceAssmt, key::Symbol) - static_get_submap(choices, Val(key)) -end - -@inline get_value(choices::StaticIRTraceAssmt, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::StaticIRTraceAssmt, addr::Pair) = _has_value(choices, addr) +@inline get_submap(choices::StaticIRTraceAssmt, key::Symbol) = static_get_submap(choices, Val(key)) @inline get_submap(choices::StaticIRTraceAssmt, addr::Pair) = _get_submap(choices, addr) ######################### @@ -36,16 +19,13 @@ end abstract type StaticIRTrace <: Trace end -@inline function static_get_subtrace(trace::StaticIRTrace, addr) - error("Not implemented") -end +@inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) -@inline function Base.getindex(trace::StaticIRTrace, addr) - Gen.static_getindex(trace, Val(addr)) -end +@inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_getindex(trace, Val(addr)) @inline function Base.getindex(trace::StaticIRTrace, addr::Pair) first, rest = addr return Gen.static_get_subtrace(trace, Val(first))[rest] @@ -161,21 +141,13 @@ function generate_get_choices(trace_struct_name::Symbol) :($(QuoteNode(EmptyChoiceMap))()))) end -function generate_get_values_shallow(ir::StaticIR, trace_struct_name::Symbol) +function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] for node in ir.choice_nodes addr = node.addr value = :(choices.trace.$(get_value_fieldname(node))) - push!(elements, :(($(QuoteNode(addr)), $value))) + push!(elements, :(($(QuoteNode(addr)), ValueChoiceMap($value)))) end - Expr(:function, - Expr(:call, Expr(:(.), Gen, 
QuoteNode(:get_values_shallow)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name})), - Expr(:block, Expr(:tuple, elements...))) -end - -function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) - elements = [] for node in ir.call_nodes addr = node.addr subtrace = :(choices.trace.$(get_subtrace_fieldname(node))) @@ -224,30 +196,6 @@ function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) return [get_subtrace_exprs; call_getindex_exprs; choice_getindex_exprs] end -function generate_static_get_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_value)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(choices.trace.$(get_value_fieldname(node)))))) - end - methods -end - -function generate_static_has_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_has_value)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(true)))) - end - methods -end - function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) methods = Expr[] for node in ir.call_nodes @@ -259,13 +207,13 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) :(get_choices(choices.trace.$(get_subtrace_fieldname(node))))))) end - # throw a KeyError if get_submap is run on an address containing a value + # return a ValueChoiceMap if get_submap is run on an address containing a value for node in ir.choice_nodes push!(methods, Expr(:function, Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_submap)), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(throw(KeyError($(QuoteNode(node.addr)))))))) + Expr(:block, :(ValueChoiceMap(choices.trace.$(get_value_fieldname(node))))))) end methods end @@ -290,18 +238,13 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_retval_expr = generate_get_retval(ir, trace_struct_name) get_choices_expr = generate_get_choices(trace_struct_name) get_schema_expr = generate_get_schema(ir, trace_struct_name) - get_values_shallow_expr = generate_get_values_shallow(ir, trace_struct_name) get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name) - static_get_value_exprs = generate_static_get_value(ir, trace_struct_name) - static_has_value_exprs = generate_static_has_value(ir, trace_struct_name) static_get_submap_exprs = generate_static_get_submap(ir, trace_struct_name) getindex_exprs = generate_getindex(ir, trace_struct_name) exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr, get_args_expr, get_retval_expr, - get_choices_expr, get_schema_expr, get_values_shallow_expr, - get_submaps_shallow_expr, static_get_value_exprs..., - static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...) + get_choices_expr, get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) 
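The generated `static_get_submap` methods above follow the convention used throughout this patch: an address holding a random choice yields a `ValueChoiceMap` rather than throwing, a missing address yields an `EmptyChoiceMap`, and only `get_value` throws (`ChoiceMapGetValueError`). A sketch on a plain `StaticChoiceMap` built with the new keyword constructor; trace-backed choicemaps are intended to behave the same way:

```julia
using Gen

inner = StaticChoiceMap(a=1.0, b=[2.0, 2.5])   # plain values are wrapped in ValueChoiceMaps
outer = StaticChoiceMap(c=3, d=inner, e=inner) # existing choicemaps become submaps

outer[:d => :b]                   # [2.0, 2.5]
get_submap(outer, :c)             # ValueChoiceMap(3)
static_get_submap(outer, Val(:c)) # ValueChoiceMap(3)
get_submap(outer, :c => :nope)    # EmptyChoiceMap()
get_value(outer, :z)              # throws ChoiceMapGetValueError
```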
(exprs, trace_struct_name) end diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index dc4fddf31..c806bba3a 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -454,9 +454,10 @@ function generate_discard!(stmts::Vector{Expr}, end leaf_keys = map((key::Symbol) -> QuoteNode(key), leaf_keys) internal_keys = map((key::Symbol) -> QuoteNode(key), internal_keys) - expr = :($(QuoteNode(StaticChoiceMap))( - $(QuoteNode(NamedTuple)){($(leaf_keys...),)}(($(leaf_nodes...),)), - $(QuoteNode(NamedTuple)){($(internal_keys...),)}(($(internal_nodes...),)))) + all_keys = (leaf_keys..., internal_keys...) + all_nodes = ([:($(QuoteNode(ValueChoiceMap))($node)) for node in leaf_nodes]..., internal_nodes...) + expr = quote $(QuoteNode(StaticChoiceMap))( + $(QuoteNode(NamedTuple)){($(all_keys...),)}(($(all_nodes...),))) end push!(stmts, :($discard = $expr)) end diff --git a/test/assignment.jl b/test/assignment.jl index 1bba754af..1d7e48a80 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -1,6 +1,46 @@ +@testset "ValueChoiceMap" begin + vcm1 = ValueChoiceMap(2) + vcm2 = ValueChoiceMap(2.) + vcm3 = ValueChoiceMap([1,2]) + @test vcm1 isa ValueChoiceMap{Int} + @test vcm2 isa ValueChoiceMap{Float64} + @test vcm3 isa ValueChoiceMap{Vector{Int}} + + @test !isempty(vcm1) + @test has_value(vcm1) + @test get_value(vcm1) == 2 + @test vcm1 == vcm2 + @test isempty(get_submaps_shallow(vcm1)) + @test isempty(get_values_shallow(vcm1)) + @test isempty(get_nonvalue_submaps_shallow(vcm1)) + @test to_array(vcm1, Int) == [2] + @test from_array(vcm1, [4]) == ValueChoiceMap(4) + @test from_array(vcm3, [4, 5]) == ValueChoiceMap([4, 5]) + @test_throws Exception merge(vcm1, vcm2) + @test_throws Exception merge(vcm1, choicemap(:a, 5)) + @test merge(vcm1, EmptyChoiceMap()) == vcm1 + @test merge(EmptyChoiceMap(), vcm1) == vcm1 + @test get_submap(vcm1, :addr) == EmptyChoiceMap() + @test_throws ChoiceMapGetValueError get_value(vcm1, :addr) + @test !has_value(vcm1, :addr) + @test isapprox(vcm2, ValueChoiceMap(prevfloat(2.))) + @test isapprox(vcm1, ValueChoiceMap(prevfloat(2.))) + @test get_address_schema(typeof(vcm1)) == EmptyAddressSchema() + @test get_address_schema(ValueChoiceMap) == EmptyAddressSchema() + @test nested_view(vcm1) == 2 +end + +@testset "static choicemap constructor" begin + @test StaticChoiceMap((a=ValueChoiceMap(5), b=ValueChoiceMap(6))) == StaticChoiceMap(a=5, b=6) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + @test submap == StaticChoiceMap((a=ValueChoiceMap(1.), b=ValueChoiceMap([2., 2.5]))) + outer = StaticChoiceMap(c=3, d=submap, e=submap) + @test outer == StaticChoiceMap((c=ValueChoiceMap(3), d=submap, e=submap)) +end + @testset "static assignment to/from array" begin - submap = StaticChoiceMap((a=1., b=[2., 2.5]),NamedTuple()) - outer = StaticChoiceMap((c=3.,), (d=submap, e=submap)) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + outer = StaticChoiceMap(c=3., d=submap, e=submap) arr = to_array(outer, Float64) @test to_array(outer, Float64) == Float64[3.0, 1.0, 2.0, 2.5, 1.0, 2.0, 2.5] @@ -11,14 +51,16 @@ @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test 
length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment to/from array" begin @@ -39,14 +81,18 @@ end @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test get_submap(choices, :c) == ValueChoiceMap(1.0) + @test get_submap(choices, :d => :b) == ValueChoiceMap([3.0, 4.0]) + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment copy constructor" begin @@ -64,25 +110,6 @@ end @test choices[:u => :w] == 4 end -@testset "internal vector assignment to/from array" begin - inner = choicemap() - set_value!(inner, :a, 1.) - set_value!(inner, :b, 2.) - outer = vectorize_internal([inner, inner, inner]) - - arr = to_array(outer, Float64) - @test to_array(outer, Float64) == Float64[1, 2, 1, 2, 1, 2] - - choices = from_array(outer, Float64[1, 2, 3, 4, 5, 6]) - @test choices[1 => :a] == 1.0 - @test choices[1 => :b] == 2.0 - @test choices[2 => :a] == 3.0 - @test choices[2 => :b] == 4.0 - @test choices[3 => :a] == 5.0 - @test choices[3 => :b] == 6.0 - @test length(collect(get_submaps_shallow(choices))) == 3 -end - @testset "dynamic assignment merge" begin submap = choicemap() set_value!(submap, :x, 1) @@ -107,7 +134,7 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @@ -125,8 +152,8 @@ end set_value!(submap, :x, 1) submap2 = choicemap() set_value!(submap2, :y, 4.) - choices1 = StaticChoiceMap((a=1., b=2.), (c=submap, shared=submap)) - choices2 = StaticChoiceMap((d=3.,), (e=submap, f=submap, shared=submap2)) + choices1 = StaticChoiceMap(a=1., b=2., c=submap, shared=submap) + choices2 = StaticChoiceMap(d=3., e=submap, f=submap, shared=submap2) choices = merge(choices1, choices2) @test choices[:a] == 1. @test choices[:b] == 2. @@ -136,124 +163,91 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. 
- @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @testset "static assignment variadic merge" begin - choices1 = StaticChoiceMap((a=1,), NamedTuple()) - choices2 = StaticChoiceMap((b=2,), NamedTuple()) - choices3 = StaticChoiceMap((c=3,), NamedTuple()) - choices_all = StaticChoiceMap((a=1, b=2, c=3), NamedTuple()) + choices1 = StaticChoiceMap(a=1) + choices2 = StaticChoiceMap(b=2) + choices3 = StaticChoiceMap(c=3) + choices_all = StaticChoiceMap(a=1, b=2, c=3) @test merge(choices1) == choices1 @test merge(choices1, choices2, choices3) == choices_all end +# TODO: in changing a lot of these to reflect the new behavior of choicemap, +# they are mostly not error checks, but instead checks for returning `EmptyChoiceMap`; +# should we relabel this testset? @testset "static assignment errors" begin + # get_choices on an address that returns a ValueChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x) == ValueChoiceMap(1) + + # static_get_submap on an address that contains a value returns a ValueChoiceMap + choices = StaticChoiceMap(x=1) + @test static_get_submap(choices, Val(:x)) == ValueChoiceMap(1) - # get_choices on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw - - # static_get_submap on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw - - # get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # get_submap on an address whose prefix contains a value returns EmptyChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) + choices = StaticChoiceMap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # static_get_choices on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # static_get_choices on an address that contains nothing returns an EmptyChoiceMap + choices = StaticChoiceMap() + @test static_get_submap(choices, Val(:x)) == EmptyChoiceMap() - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # static_get_value on an address that contains a submap throws a KeyError + # static_get_value on an 
address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw - - # get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) + + # get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) + + # static_get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) end @testset "dynamic assignment errors" begin - - # get_choices on an address that contains a value throws a KeyError + # get_choices on an address that contains a value returns a ValueChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == ValueChoiceMap(1) - # get_choices on an address whose prefix contains a value throws a KeyError + # get_choices on an address whose prefix contains a value returns EmptyChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment choices = choicemap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError choices = choicemap() choices[:x => :y] = 1 - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # get_value on an address that contains nothing throws a KeyError + # get_value on an address that contains nothing throws a ChoiceMapGetValueError choices = choicemap() - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) end @testset "dynamic assignment overwrite" begin @@ -276,9 +270,7 @@ end choices = choicemap() choices[:x => :y] = 1 choices[:x] = 2 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == ValueChoiceMap(2) @test choices[:x] == 2 # overwrite subassignment with a subassignment @@ -293,17 +285,13 @@ end # illegal set value under existing value choices = choicemap() choices[:x] = 
1 - threw = false - try set_value!(choices, :x => :y, 2) catch KeyError threw = true end - @test threw + @test_throws Exception set_value!(choices, :x => :y, 2) # illegal set submap under existing value choices = choicemap() choices[:x] = 1 submap = choicemap(); choices[:z] = 2 - threw = false - try set_submap!(choices, :x => :y, submap) catch KeyError threw = true end - @test threw + @test_throws Exception set_submap!(choices, :x => :y, submap) end @testset "dynamic assignment constructor" begin diff --git a/test/benchmark.md b/test/benchmark.md new file mode 100644 index 000000000..adabb8a58 --- /dev/null +++ b/test/benchmark.md @@ -0,0 +1,21 @@ +NEW version: +static choicemap nonnested lookup: + 0.728112 seconds (149.59 k allocations: 4.259 MiB) + 0.785652 seconds (100.00 k allocations: 1.526 MiB) + 0.693433 seconds (100.00 k allocations: 1.526 MiB) + 0.660211 seconds (100.00 k allocations: 1.526 MiB) +static choicemap nested lookup: + 0.680497 seconds (49.59 k allocations: 2.732 MiB) + 0.665768 seconds (1 allocation: 32 bytes) + 0.666708 seconds (1 allocation: 32 bytes) + 0.671009 seconds (1 allocation: 32 bytes) +static gen function choicemap nonnested lookup: + 0.701754 seconds (62.76 k allocations: 3.415 MiB) + 0.662916 seconds + 0.659019 seconds + 0.663398 seconds +static gen function choicemap nested lookup: + 1.338034 seconds (172.13 k allocations: 5.352 MiB) + 1.311123 seconds (100.00 k allocations: 1.526 MiB) + 1.311800 seconds (100.00 k allocations: 1.526 MiB) + 1.310289 seconds (100.00 k allocations: 1.526 MiB) \ No newline at end of file diff --git a/test/dynamic_dsl.jl b/test/dynamic_dsl.jl index 35f81703d..5561ae549 100644 --- a/test/dynamic_dsl.jl +++ b/test/dynamic_dsl.jl @@ -119,7 +119,7 @@ end @test get_value(discard, :x) == x @test get_value(discard, :u => :a) == a @test length(collect(get_values_shallow(discard))) == 2 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 # test new trace new_assignment = get_choices(new_trace) @@ -127,7 +127,7 @@ end @test get_value(new_assignment, :y) == y @test get_value(new_assignment, :v => :b) == b @test length(collect(get_values_shallow(new_assignment))) == 2 - @test length(collect(get_submaps_shallow(new_assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(new_assignment))) == 1 # test score and weight prev_score = ( @@ -242,7 +242,7 @@ end @test !isempty(get_submap(assignment, :v)) end @test length(collect(get_values_shallow(assignment))) == 2 - @test length(collect(get_submaps_shallow(assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(assignment))) == 1 # test weight if assignment[:branch] == prev_assignment[:branch] @@ -332,11 +332,11 @@ end @test get_value(choices, :out) == out @test get_value(choices, :bar => :z) == z @test !has_value(choices, :b) # was not selected - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(collect(get_values_shallow(choices))) == 2 # check gradient trie - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(gradients))) == 2 @test !has_value(gradients, :b) # was not selected @test isapprox(get_value(gradients, :bar => :z), @@ -431,14 +431,14 @@ end @test choices[:x => 2] == 2 @test choices[:x => 3 => :z] == 3 @test length(collect(get_values_shallow(choices))) == 1 # :y - @test 
length(collect(get_submaps_shallow(choices))) == 1 # :x + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # :x submap = get_submap(choices, :x) @test submap[1] == 1 @test submap[2] == 2 @test submap[3 => :z] == 3 @test length(collect(get_values_shallow(submap))) == 2 # 1, 2 - @test length(collect(get_submaps_shallow(submap))) == 1 # 3 + @test length(collect(get_nonvalue_submaps_shallow(submap))) == 1 # 3 bar_submap = get_submap(submap, 3) @test bar_submap[:z] == 3 diff --git a/test/modeling_library/call_at.jl b/test/modeling_library/call_at.jl index b27f0130d..607eb61fd 100644 --- a/test/modeling_library/call_at.jl +++ b/test/modeling_library/call_at.jl @@ -20,7 +20,7 @@ y = choices[3 => :y] @test isapprox(weight, logpdf(normal, y, 0.4, 1)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end @testset "generate" begin @@ -32,7 +32,7 @@ y = choices[3 => :y] @test get_retval(trace) == 0.4 + y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # with constraints y = 1.234 @@ -44,7 +44,7 @@ @test get_retval(trace) == 0.4 + y @test isapprox(weight, logpdf(normal, y, 0.4, 1.)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end function get_trace() @@ -71,7 +71,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isempty(discard) @@ -86,12 +86,12 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) # change kernel_args, different key, with constraint @@ -103,12 +103,12 @@ choices = get_choices(new_trace) @test choices[4 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) end @@ -121,7 +121,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test 
length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isapprox(get_score(new_trace), logpdf(normal, y, 0.2, 1)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) y_new = choices[3 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -144,7 +144,7 @@ choices = get_choices(new_trace) y_new = choices[4 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -171,9 +171,9 @@ @test choices[3 => :y] == y @test isapprox(gradients[3 => :y], logpdf_grad(normal, y, 0.4, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 0 - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(input_grads) == 2 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.4, 1.0)[2] + retval_grad) @test input_grads[2] == nothing # the key has no gradient diff --git a/test/modeling_library/choice_at.jl b/test/modeling_library/choice_at.jl index 080b1b461..4f5241381 100644 --- a/test/modeling_library/choice_at.jl +++ b/test/modeling_library/choice_at.jl @@ -15,7 +15,7 @@ @test isapprox(weight, value ? 
log(0.4) : log(0.6)) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end @testset "generate" begin @@ -27,7 +27,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 # with constraints constraints = choicemap() @@ -39,7 +39,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end function get_trace() @@ -65,7 +65,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isempty(discard) @@ -78,12 +78,12 @@ choices = get_choices(new_trace) @test choices[3] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 # change kernel_args, different key, with constraint constraints = choicemap() @@ -93,12 +93,12 @@ choices = get_choices(new_trace) @test choices[4] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 end @testset "regenerate" begin @@ -110,7 +110,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isapprox(get_score(new_trace), log(0.2)) @@ -122,7 +122,7 @@ choices = get_choices(new_trace) value = choices[3] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) value = choices[4] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 
0.2 : 1 - 0.2)) @@ -163,9 +163,9 @@ @test choices[3] == y @test isapprox(gradients[3], logpdf_grad(normal, y, 0.0, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 1 - @test length(collect(get_submaps_shallow(gradients))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 0 @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test length(input_grads) == 3 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.0, 1.0)[2]) @test isapprox(input_grads[2], logpdf_grad(normal, y, 0.0, 1.0)[3]) diff --git a/test/modeling_library/recurse.jl b/test/modeling_library/recurse.jl index 46954e3be..b440a44fa 100644 --- a/test/modeling_library/recurse.jl +++ b/test/modeling_library/recurse.jl @@ -197,9 +197,9 @@ end @test choices[(4, Val(:production)) => :rule] == 4 @test choices[(4, Val(:aggregation)) => :prefix] == false @test discard[(3, Val(:aggregation)) => :prefix] == true - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 @test length(collect(get_values_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 1 @test retdiff == UnknownChange() diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index ba748453b..0f3a56180 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -28,7 +28,7 @@ x3 = trace[3 => :x] choices = get_choices(trace) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 expected_score = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x2, x1 * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -55,7 +55,7 @@ @test choices[1 => :x] == x1 @test choices[3 => :x] == x3 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x2 = choices[2 => :x] expected_weight = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -77,7 +77,7 @@ beta = 0.3 (choices, weight, retval) = propose(foo, (3, x_init, alpha, beta)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x1 = choices[1 => :x] x2 = choices[2 => :x] x3 = choices[3 => :x] diff --git a/test/optional_args.jl b/test/optional_args.jl index fd6c4ea71..b0fb821bd 100644 --- a/test/optional_args.jl +++ b/test/optional_args.jl @@ -1,4 +1,4 @@ -using Gen +#using Gen @testset "optional positional args (calling + GFI)" begin diff --git a/test/runtests.jl b/test/runtests.jl index a67a5f782..749236037 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,4 +74,4 @@ include("static_ir/static_ir.jl") include("static_dsl.jl") include("tilde_sugar.jl") include("inference/inference.jl") -include("modeling_library/modeling_library.jl") +include("modeling_library/modeling_library.jl") \ No newline at end of file diff --git 
a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 91c6c3202..9e2cecf3b 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -1,4 +1,6 @@ using Gen: generate_generative_function +using Test +using Gen @testset "static IR" begin @@ -362,12 +364,12 @@ end @test get_value(value_trie, :out) == out @test get_value(value_trie, :bar => :z) == z @test !has_value(value_trie, :b) # was not selected - @test length(get_submaps_shallow(value_trie)) == 1 - @test length(get_values_shallow(value_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(value_trie))) == 1 + @test length(collect(get_values_shallow(value_trie))) == 2 # check gradient trie - @test length(get_submaps_shallow(gradient_trie)) == 1 - @test length(get_values_shallow(gradient_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(gradient_trie))) == 1 + @test length(collect(get_values_shallow(gradient_trie))) == 2 @test !has_value(gradient_trie, :b) # was not selected @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) diff --git a/test/tilde_sugar.jl b/test/tilde_sugar.jl index fbd528b76..8396fe517 100644 --- a/test/tilde_sugar.jl +++ b/test/tilde_sugar.jl @@ -1,4 +1,4 @@ -using Gen +using .Gen import MacroTools normalize(ex) = MacroTools.prewalk(MacroTools.rmlines, ex) From 83349c7d4a320e028c9b24e26da4c3b44066fce9 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 23:13:03 -0400 Subject: [PATCH 05/45] performance improvements and benchmarking --- src/choice_map/choice_map.jl | 8 ++--- src/choice_map/static_choice_map.jl | 6 ++-- src/static_ir/trace.jl | 15 +++------ test/static_choicemap_benchmark.jl | 50 +++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 18 deletions(-) create mode 100644 test/static_choicemap_benchmark.jl diff --git a/src/choice_map/choice_map.jl b/src/choice_map/choice_map.jl index 402cefa37..213bc5f80 100644 --- a/src/choice_map/choice_map.jl +++ b/src/choice_map/choice_map.jl @@ -71,8 +71,8 @@ A syntactic sugar is `Base.getindex`: value = choices[addr] """ function get_value end -get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) -get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) +@inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) 
""" @@ -145,8 +145,8 @@ end @inline get_value(choices::ValueChoiceMap) = choices.val @inline get_submap(choices::ValueChoiceMap, addr) = EmptyChoiceMap() @inline get_submaps_shallow(choices::ValueChoiceMap) = () -Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val -Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) +@inline Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val +@inline Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) @inline get_address_schema(::Type{<:ValueChoiceMap}) = EmptyAddressSchema() """ diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index 1f75b3bca..58ef57d37 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -26,10 +26,10 @@ end quote EmptyChoiceMap() end end end -static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() +@inline static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() -static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) -static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) +@inline static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) +@inline static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) # convert a nonvalue choicemap all of whose top-level-addresses # are symbols into a staticchoicemap at the top level diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 5ac3ced16..168ccf50e 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -23,7 +23,7 @@ abstract type StaticIRTrace <: Trace end @inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false - Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) +@inline Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) @inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_getindex(trace, Val(addr)) @inline function Base.getindex(trace::StaticIRTrace, addr::Pair) @@ -31,6 +31,8 @@ abstract type StaticIRTrace <: Trace end return Gen.static_get_subtrace(trace, Val(first))[rest] end +@inline get_choices(trace::T) where {T <: StaticIRTrace} = StaticIRTraceAssmt{T}(trace) + const arg_prefix = gensym("arg") const choice_value_prefix = gensym("choice_value") const choice_score_prefix = gensym("choice_score") @@ -133,14 +135,6 @@ function generate_get_retval(ir::StaticIR, trace_struct_name::Symbol) Expr(:block, :(trace.$return_value_fieldname))) end -function generate_get_choices(trace_struct_name::Symbol) - Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_choices)), :(trace::$trace_struct_name)), - Expr(:if, :(!isempty(trace)), - :($(QuoteNode(StaticIRTraceAssmt))(trace)), - :($(QuoteNode(EmptyChoiceMap))()))) -end - function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] for node in ir.choice_nodes @@ -236,7 +230,6 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_score_expr = generate_get_score(trace_struct_name) get_args_expr = generate_get_args(ir, trace_struct_name) get_retval_expr = generate_get_retval(ir, trace_struct_name) - get_choices_expr = generate_get_choices(trace_struct_name) get_schema_expr = generate_get_schema(ir, trace_struct_name) get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name) static_get_submap_exprs = 
generate_static_get_submap(ir, trace_struct_name) @@ -244,7 +237,7 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr, get_args_expr, get_retval_expr, - get_choices_expr, get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) + get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) (exprs, trace_struct_name) end diff --git a/test/static_choicemap_benchmark.jl b/test/static_choicemap_benchmark.jl new file mode 100644 index 000000000..1e62b9a8e --- /dev/null +++ b/test/static_choicemap_benchmark.jl @@ -0,0 +1,50 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +scm = StaticChoiceMap(a=1, b=StaticChoiceMap(c=2)) + +println("static choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(scm) +end + +println("static choicemap nested lookup:") +for _=1:4 + @time many_nested(scm) +end + +@gen (static) function inner() + c ~ normal(0, 1) +end +@gen (static) function outer() + a ~ normal(0, 1) + b ~ inner() +end + +load_generated_functions() + +tr, _ = generate(outer, ()) +choices = get_choices(tr) + +println("static gen function choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(choices) +end + +println("static gen function choicemap nested lookup:") +for _=1:4 + @time many_nested(choices) +end From b9b5312e990fc49b08611b7077b7c6f3aa5d99ee Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 23:21:45 -0400 Subject: [PATCH 06/45] benchmark for dynamic choicemap lookups --- test/dynamic_choicemap_benchmark.jl | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 test/dynamic_choicemap_benchmark.jl diff --git a/test/dynamic_choicemap_benchmark.jl b/test/dynamic_choicemap_benchmark.jl new file mode 100644 index 000000000..3724e44de --- /dev/null +++ b/test/dynamic_choicemap_benchmark.jl @@ -0,0 +1,27 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +cm = choicemap((:a, 1), (:b => :c, 2)) + +println("dynamic choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(cm) +end + +println("dynamic choicemap nested lookup:") +for _=1:4 + @time many_nested(cm) +end \ No newline at end of file From bce5e7724db64175bf2fd0f15fe25a4dc68af13e Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 23:30:40 -0400 Subject: [PATCH 07/45] inline dynamicchoicemap methods --- src/choice_map/dynamic_choice_map.jl | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/choice_map/dynamic_choice_map.jl b/src/choice_map/dynamic_choice_map.jl index a3403307c..0f27c89d7 100644 --- a/src/choice_map/dynamic_choice_map.jl +++ b/src/choice_map/dynamic_choice_map.jl @@ -67,16 +67,10 @@ function choicemap(tuples...) DynamicChoiceMap(tuples...) 
end -get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps -function get_submap(choices::DynamicChoiceMap, addr) - if haskey(choices.submaps, addr) - choices.submaps[addr] - else - EmptyChoiceMap() - end -end -get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) -Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) +@inline get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps +@inline get_submap(choices::DynamicChoiceMap, addr) = get(choices.submaps, addr, EmptyChoiceMap()) +@inline get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) +@inline Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) # mutation (not part of the assignment interface) From a985f9bd3dc3f8806e2da1e7c81fbe891334bac9 Mon Sep 17 00:00:00 2001 From: georgematheos Date: Tue, 19 May 2020 09:13:32 -0400 Subject: [PATCH 08/45] remove old version benchmark file --- test/benchmark.md | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 test/benchmark.md diff --git a/test/benchmark.md b/test/benchmark.md deleted file mode 100644 index adabb8a58..000000000 --- a/test/benchmark.md +++ /dev/null @@ -1,21 +0,0 @@ -NEW version: -static choicemap nonnested lookup: - 0.728112 seconds (149.59 k allocations: 4.259 MiB) - 0.785652 seconds (100.00 k allocations: 1.526 MiB) - 0.693433 seconds (100.00 k allocations: 1.526 MiB) - 0.660211 seconds (100.00 k allocations: 1.526 MiB) -static choicemap nested lookup: - 0.680497 seconds (49.59 k allocations: 2.732 MiB) - 0.665768 seconds (1 allocation: 32 bytes) - 0.666708 seconds (1 allocation: 32 bytes) - 0.671009 seconds (1 allocation: 32 bytes) -static gen function choicemap nonnested lookup: - 0.701754 seconds (62.76 k allocations: 3.415 MiB) - 0.662916 seconds - 0.659019 seconds - 0.663398 seconds -static gen function choicemap nested lookup: - 1.338034 seconds (172.13 k allocations: 5.352 MiB) - 1.311123 seconds (100.00 k allocations: 1.526 MiB) - 1.311800 seconds (100.00 k allocations: 1.526 MiB) - 1.310289 seconds (100.00 k allocations: 1.526 MiB) \ No newline at end of file From 1f5029cfc1637d4d3ac257cd46835312131c6ee2 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 09:15:55 -0400 Subject: [PATCH 09/45] minor testing cleanup --- test/optional_args.jl | 2 +- test/static_inference_benchmark.jl | 23 +++++++++++++++++++++++ test/static_ir/static_ir.jl | 2 -- test/tilde_sugar.jl | 2 +- 4 files changed, 25 insertions(+), 4 deletions(-) create mode 100644 test/static_inference_benchmark.jl diff --git a/test/optional_args.jl b/test/optional_args.jl index b0fb821bd..fd6c4ea71 100644 --- a/test/optional_args.jl +++ b/test/optional_args.jl @@ -1,4 +1,4 @@ -#using Gen +using Gen @testset "optional positional args (calling + GFI)" begin diff --git a/test/static_inference_benchmark.jl b/test/static_inference_benchmark.jl new file mode 100644 index 000000000..b70d08be2 --- /dev/null +++ b/test/static_inference_benchmark.jl @@ -0,0 +1,23 @@ +using Gen + +@gen (static, diffs) function foo() + a ~ normal(0, 1) + b ~ normal(a, 1) + c ~ normal(b, 1) +end + +@load_generated_functions + +observations = StaticChoiceMap(choicemap((:b,2), (:c,1.5))) +tr, _ = generate(foo, (), observations) + +function run_inference(trace) + tr = trace + for _=1:10^3 + tr, acc = mh(tr, select(:a)) + end +end + +for _=1:4 + @time run_inference(tr) +end \ No newline at end of file diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 9e2cecf3b..1b594d39d 100644 
--- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -1,6 +1,4 @@ using Gen: generate_generative_function -using Test -using Gen @testset "static IR" begin diff --git a/test/tilde_sugar.jl b/test/tilde_sugar.jl index 8396fe517..fbd528b76 100644 --- a/test/tilde_sugar.jl +++ b/test/tilde_sugar.jl @@ -1,4 +1,4 @@ -using .Gen +using Gen import MacroTools normalize(ex) = MacroTools.prewalk(MacroTools.rmlines, ex) From eb6adf7a76c5975fa20d7567a560175588aafed4 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:33:16 -0400 Subject: [PATCH 10/45] ensure valuechoicemap[] syntax works --- test/assignment.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/assignment.jl b/test/assignment.jl index 1d7e48a80..69890297f 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -5,6 +5,8 @@ @test vcm1 isa ValueChoiceMap{Int} @test vcm2 isa ValueChoiceMap{Float64} @test vcm3 isa ValueChoiceMap{Vector{Int}} + @test vcm1[] == 2 + @test vcm1[] == get_value(vcm1) @test !isempty(vcm1) @test has_value(vcm1) From eef941776857c50d8ad93ead2ee0d164d60f737e Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:43:49 -0400 Subject: [PATCH 11/45] provide some examples in the documentation --- docs/src/ref/choice_maps.md | 38 ++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index 6c445df6f..bf742f904 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -14,6 +14,42 @@ for sub-choicemaps. Leaf nodes have type: ValueChoiceMap ``` +### Example Usage Overview + +Choicemaps store values nested in a tree where each node posesses an address for each subtree. +A leaf-node choicemap simply contains a value, and has it's value looked up via: +```julia +value = choicemap[] +``` +If a choicemap has a value choicemap at address `:a`, it is looked up via: +```julia +value = choicemap[:a] +``` +And a choicemap may also have a non-value choicemap stored at a value. For instance, +if a choicemap has another choicemap stored at address `:a`, and this internal choicemap +has a valuechoicemap stored at address `:b` and another at `:c`, we could perform the following lookups: +```julia +value1 = choicemap[:a => :b] +value2 = choicemap[:a => :c] +``` +Nesting can be arbitrarily deep, and the keys can be arbitrary values; for instance +choicemaps can be constructed with values at the following nested addresses: +```julia +value = choicemap[:a => :b => :c => 4 => 1.63 => :e] +value = choicemap[:a => :b => :a => 2 => "alphabet" => :e] +``` +To get a sub-choicemap, use `get_submap`: +```julia +value1 = choicemap[:a => :b] +submap = get_submap(choicemap, :a) +value1 == submap[:b] # is true + +value_submap = get_submap(choicemap, :a => :b) +value_submap[] == value1 # is true +``` + +### Interface + Choice maps provide the following methods: ```@docs get_submap @@ -58,7 +94,7 @@ set_value! set_submap! ``` -## Implementing custom choicemap types +### Implementing custom choicemap types To implement a custom choicemap, one must implement `get_submap` and `get_submaps_shallow`. 
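For concreteness, here is a minimal sketch of what such an implementation might look like. The `DictChoiceMap` type, its `Dict`-backed storage, and the example addresses are purely illustrative (they are not part of this patch), and the sketch assumes the `ValueChoiceMap`/`EmptyChoiceMap` behavior described above:
```julia
using Gen

# Illustrative only: a custom choicemap backed by a Dict whose entries are all leaf values.
struct DictChoiceMap <: ChoiceMap
    dict::Dict{Symbol,Any}
end

# Return the submap at `addr`, wrapping each stored value in a ValueChoiceMap leaf,
# or EmptyChoiceMap() if nothing is stored at that address.
function Gen.get_submap(choices::DictChoiceMap, addr::Symbol)
    haskey(choices.dict, addr) ? ValueChoiceMap(choices.dict[addr]) : EmptyChoiceMap()
end

# Iterate (address, submap) pairs for each top-level address.
function Gen.get_submaps_shallow(choices::DictChoiceMap)
    ((addr, ValueChoiceMap(val)) for (addr, val) in choices.dict)
end

cm = DictChoiceMap(Dict(:x => 1, :y => 2.5))
cm[:x]             # 1, via the generic lookup fallbacks
has_value(cm, :y)  # true
```
With just these two methods defined, the generic `choicemap[addr]` and `has_value` lookups shown earlier should work through the default fallbacks.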
From a83adfbc4d02bed4e9c0a78163151c742cc660f8 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:50:25 -0400 Subject: [PATCH 12/45] fix some typos --- docs/src/ref/choice_maps.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index bf742f904..2963d304a 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -17,15 +17,15 @@ ValueChoiceMap ### Example Usage Overview Choicemaps store values nested in a tree where each node posesses an address for each subtree. -A leaf-node choicemap simply contains a value, and has it's value looked up via: +A leaf-node choicemap simply contains a value, and has its value looked up via: ```julia value = choicemap[] ``` -If a choicemap has a value choicemap at address `:a`, it is looked up via: +If a choicemap has a value choicemap at address `:a`, the value it stores is looked up via: ```julia value = choicemap[:a] ``` -And a choicemap may also have a non-value choicemap stored at a value. For instance, +A choicemap may also have a non-value choicemap stored at an address. For instance, if a choicemap has another choicemap stored at address `:a`, and this internal choicemap has a valuechoicemap stored at address `:b` and another at `:c`, we could perform the following lookups: ```julia value1 = choicemap[:a => :b] From 1bd705f101bb7c783aedad30fe442f864bcec625 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:54:25 -0400 Subject: [PATCH 13/45] add phrase 'nesting level zero' to docs --- docs/src/ref/choice_maps.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index 2963d304a..4a23b7cfa 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -47,6 +47,8 @@ value1 == submap[:b] # is true value_submap = get_submap(choicemap, :a => :b) value_submap[] == value1 # is true ``` +One can think of a `ValueChoiceMap` as being a choicemap which has a value at "nesting level zero", +while other choicemaps have values at "nesting level" one or higher.
### Interface From 676828b0d16872f35ee2a327e2d445df9a449269 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 17 Jun 2020 15:29:17 -0400 Subject: [PATCH 14/45] distribution <: GenFn; dynamic DSL simplification --- src/Gen.jl | 9 +- src/distribution.jl | 123 ++++++++++++++++++++++ src/dynamic/assess.jl | 16 --- src/dynamic/backprop.jl | 4 +- src/dynamic/dynamic.jl | 6 -- src/dynamic/generate.jl | 32 ------ src/dynamic/project.jl | 14 +-- src/dynamic/propose.jl | 19 ---- src/dynamic/regenerate.jl | 58 ++--------- src/dynamic/simulate.jl | 18 ---- src/dynamic/trace.jl | 127 ++++------------------- src/dynamic/update.jl | 64 ++---------- src/modeling_library/modeling_library.jl | 48 --------- 13 files changed, 166 insertions(+), 372 deletions(-) create mode 100644 src/distribution.jl diff --git a/src/Gen.jl b/src/Gen.jl index fa2393596..fe3cdaa12 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -42,13 +42,16 @@ include("choice_map/choice_map.jl") # a homogeneous trie data type (not for use as choice map) include("trie.jl") +# built-in data types for arg-diff and ret-diff values +include("diff.jl") + # generative function interface include("gen_fn_interface.jl") -# built-in data types for arg-diff and ret-diff values -include("diff.jl") +# distribution abstract type +include("distribution.jl") -# built-in probability disributions +# built-in probability distributions; distribution dsl; combinators include("modeling_library/modeling_library.jl") # optimization of trainable parameters diff --git a/src/distribution.jl b/src/distribution.jl new file mode 100644 index 000000000..d72b21a61 --- /dev/null +++ b/src/distribution.jl @@ -0,0 +1,123 @@ +############################### +# Core Distribution Interface # +############################### + +struct DistributionTrace{T, Dist} <: Trace + val::T + args + dist::Dist +end + +abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end + +function Base.convert(::Type{DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} + DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, tr.dist) +end + +""" + val::T = random(dist::Distribution{T}, args...) + +Sample a random choice from the given distribution with the given arguments. +""" +function random end + +""" + lpdf = logpdf(dist::Distribution{T}, value::T, args...) + +Evaluate the log probability (density) of the value. +""" +function logpdf end + +""" + has::Bool = has_output_grad(dist::Distribution) + +Return true if the distribution computes the gradient of the logpdf with respect to the value of the random choice. +""" +function has_output_grad end + +""" + grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...) + +Compute the gradient of the logpdf with respect to the value, and each of the arguments. + +If `has_output_grad` returns false, then the first element of the returned tuple is `nothing`. +Otherwise, the first element of the tuple is the gradient with respect to the value. +If the return value of `has_argument_grads` has a false value at position `i`, then the `i+1`th element of the returned tuple has value `nothing`. +Otherwise, this element contains the gradient with respect to the `i`th argument.
+""" +function logpdf_grad end + +function is_discrete end + +# NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl + +get_return_type(::Distribution{T}) where {T} = T + + +############################## +# Distribution GFI Interface # +############################## + +@inline Base.getindex(trace::DistributionTrace) = trace.val +@inline Gen.get_args(trace::DistributionTrace) = trace.args +@inline Gen.get_choices(trace::DistributionTrace) = ValueChoiceMap(trace.val) # should be able to get type of val +@inline Gen.get_retval(trace::DistributionTrace) = trace.val +@inline Gen.get_gen_fn(trace::DistributionTrace) = trace.dist + +# TODO: for performance would it be better to store the score in the trace? +@inline Gen.get_score(trace::DistributionTrace) = logpdf(trace.dist, trace.val, trace.args...) +@inline Gen.project(trace::DistributionTrace, ::EmptySelection) = 0. +@inline Gen.project(trace::DistributionTrace, ::AllSelection) = get_score(trace) + +@inline function Gen.simulate(dist::Distribution, args::Tuple) + val = random(dist, args...) + DistributionTrace(val, args, dist) +end +@inline Gen.generate(dist::Distribution, args::Tuple, ::EmptyChoiceMap) = (simulate(dist, args), 0.) +@inline function Gen.generate(dist::Distribution, args::Tuple, constraints::ValueChoiceMap) + tr = DistributionTrace(get_value(constraints), args, dist) + weight = get_score(tr) + (tr, weight) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::ValueChoiceMap) + new_tr = DistributionTrace(get_value(constraints), args, tr.dist) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, UnknownChange(), get_choices(tr)) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::EmptyChoiceMap) + new_tr = DistributionTrace(tr.val, args, tr.dist) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, NoChange(), EmptyChoiceMap()) +end +# TODO: do I need an update method to handle empty choicemaps which are not `EmptyChoiceMap`s? +@inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, selection::EmptySelection) where {n} = (tr, 0., NoChange()) +@inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::EmptySelection) + new_tr = DistributionTrace(tr.val, args, tr.dist) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, NoChange()) +end +@inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::AllSelection) + new_val = random(tr.dist, args...) + new_tr = DistributionTrace(new_val, args, tr.dist) + (new_tr, 0., UnknownChange()) +end +@inline function Gen.propose(dist::Distribution, args::Tuple) + val = random(dist, args...) + score = logpdf(dist, val, args...) + (ValueChoiceMap(val), score, val) +end +@inline function Gen.assess(dist::Distribution, args::Tuple, choices::ValueChoiceMap) + weight = logpdf(dist, choices.val, args...) 
+ (weight, choices.val) +end + +########### +# Exports # +########### + +export Distribution +export random +export logpdf +export logpdf_grad +export has_output_grad +export is_discrete diff --git a/src/dynamic/assess.jl b/src/dynamic/assess.jl index c583d5079..0bf37a077 100644 --- a/src/dynamic/assess.jl +++ b/src/dynamic/assess.jl @@ -9,22 +9,6 @@ function GFAssessState(choices, params::Dict{Symbol,Any}) GFAssessState(choices, 0., AddressVisitor(), params) end -function traceat(state::GFAssessState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # get return value - retval = get_value(state.choices, key) - - # update weight - state.weight += logpdf(dist, retval, args...) - - retval -end - function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local retval::T diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index 6a7278a02..e870d0c77 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -74,7 +74,7 @@ function traceat(state::GFBackpropParamsState, dist::Distribution{T}, args_maybe_tracked, key) where {T} local retval::T visit!(state.visitor, key) - retval = get_choice(state.trace, key).retval + retval = get_retval(get_call(state.trace, key).subtrace) args = map(value, args_maybe_tracked) score_tracked = track(logpdf(dist, retval, args...), state.tape) record!(state.tape, ReverseDiff.SpecialInstruction, dist, @@ -275,7 +275,7 @@ function traceat(state::GFBackpropTraceState, dist::Distribution{T}, args_maybe_tracked, key) where {T} local retval::T visit!(state.visitor, key) - retval = get_choice(state.trace, key).retval + retval = get_retval(get_call(state.trace, key).subtrace) args = map(value, args_maybe_tracked) score_tracked = track(logpdf(dist, retval, args...), state.tape) if key in state.selection diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index 73f22159a..0d8e03b4c 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -154,12 +154,6 @@ function check_is_empty(constraints::ChoiceMap, addr) end end -function check_no_value(constraints::ChoiceMap, addr) - if has_value(constraints, addr) - error("Expected a sub-assignment at address $addr but found a value") - end -end - function gen_fn_changed_error(addr) error("Generative function changed at address: $addr") end diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index 970dac42d..4a5796aae 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -11,38 +11,6 @@ function GFGenerateState(gen_fn, args, constraints, params) GFGenerateState(trace, constraints, 0., AddressVisitor(), params) end -function traceat(state::GFGenerateState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for constraints at this key - constrained = has_value(state.constraints, key) - !constrained && check_is_empty(state.constraints, key) - - # get return value - if constrained - retval = get_value(state.constraints, key) - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) 
- - # add to the trace - add_choice!(state.trace, key, retval, score) - - # increment weight - if constrained - state.weight += score - end - - retval -end - function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local subtrace::U diff --git a/src/dynamic/project.jl b/src/dynamic/project.jl index 358398e55..81b45b3c6 100644 --- a/src/dynamic/project.jl +++ b/src/dynamic/project.jl @@ -1,15 +1,9 @@ -function project_recurse(trie::Trie{Any,ChoiceOrCallRecord}, +function project_recurse(trie::Trie{Any, CallRecord}, selection::Selection) weight = 0. - for (key, choice_or_call) in get_leaf_nodes(trie) - if choice_or_call.is_choice - if key in selection - weight += choice_or_call.score - end - else - subselection = selection[key] - weight += project(choice_or_call.subtrace_or_retval, subselection) - end + for (key, call) in get_leaf_nodes(trie) + subselection = selection[key] + weight += project(call.subtrace, subselection) end for (key, subtrie) in get_internal_nodes(trie) subselection = selection[key] diff --git a/src/dynamic/propose.jl b/src/dynamic/propose.jl index e4281f49e..32fc95da2 100644 --- a/src/dynamic/propose.jl +++ b/src/dynamic/propose.jl @@ -9,25 +9,6 @@ function GFProposeState(params::Dict{Symbol,Any}) GFProposeState(choicemap(), 0., AddressVisitor(), params) end -function traceat(state::GFProposeState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # sample return value - retval = random(dist, args...) - - # update assignment - set_value!(state.choices, key, retval) - - # update weight - state.weight += logpdf(dist, retval, args...) - - retval -end - function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local retval::T diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl index 81ba8b3c4..a4006a6c8 100644 --- a/src/dynamic/regenerate.jl +++ b/src/dynamic/regenerate.jl @@ -14,48 +14,6 @@ function GFRegenerateState(gen_fn, args, prev_trace, 0., visitor, params) end -function traceat(state::GFRegenerateState, dist::Distribution{T}, - args, key) where {T} - local prev_retval::T - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for previous choice at this key - has_previous = has_choice(state.prev_trace, key) - if has_previous - prev_choice = get_choice(state.prev_trace, key) - prev_retval = prev_choice.retval - prev_score = prev_choice.score - end - - # check whether the key was selected - in_selection = key in state.selection - - # get return value - if has_previous && in_selection - retval = random(dist, args...) - elseif has_previous - retval = prev_retval - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) - - # update weight - if has_previous && !in_selection - state.weight += score - prev_score - end - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local prev_retval::T @@ -101,13 +59,11 @@ function splice(state::GFRegenerateState, gen_fn::DynamicDSLFunction, retval end -function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function regenerate_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::EmptySelection) noise = 0. 
- for (key, choice_or_call) in get_leaf_nodes(prev_trie) - if !choice_or_call.is_choice - noise += choice_or_call.noise - end + for (key, call) in get_leaf_nodes(prev_trie) + noise += call.noise end for (key, subtrie) in get_internal_nodes(prev_trie) noise += regenerate_delete_recurse(subtrie, EmptySelection()) @@ -115,12 +71,12 @@ function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, noise end -function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function regenerate_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::DynamicSelection) noise = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - if !(key in visited) && !choice_or_call.is_choice - noise += choice_or_call.noise + for (key, call) in get_leaf_nodes(prev_trie) + if !(key in visited) + noise += call.noise end end for (key, subtrie) in get_internal_nodes(prev_trie) diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index 0addd8bfb..57f709dca 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -9,24 +9,6 @@ function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params) GFSimulateState(trace, AddressVisitor(), params) end -function traceat(state::GFSimulateState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - retval = random(dist, args...) - - # compute logpdf - score = logpdf(dist, retval, args...) - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local subtrace::U diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 882297e43..0a169d737 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -1,96 +1,37 @@ -struct ChoiceRecord{T} - retval::T - score::Float64 -end - struct CallRecord{T} subtrace::T score::Float64 noise::Float64 end -struct ChoiceOrCallRecord{T} - subtrace_or_retval::T - score::Float64 - noise::Float64 # if choice then NaN - is_choice::Bool -end - -function ChoiceRecord(record::ChoiceOrCallRecord) - if !record.is_choice - error("Found call but expected choice") - end - ChoiceRecord(record.subtrace_or_retval, record.score) -end - -function CallRecord(record::ChoiceOrCallRecord) - if record.is_choice - error("Found choice but expected call") - end - CallRecord(record.subtrace_or_retval, record.score, record.noise) -end - mutable struct DynamicDSLTrace{T} <: Trace gen_fn::T - trie::Trie{Any,ChoiceOrCallRecord} - isempty::Bool + trie::Trie{Any,CallRecord} score::Float64 noise::Float64 args::Tuple retval::Any function DynamicDSLTrace{T}(gen_fn::T, args) where {T} - trie = Trie{Any,ChoiceOrCallRecord}() + trie = Trie{Any,CallRecord}() # retval is not known yet - new(gen_fn, trie, true, 0, 0, args) + new(gen_fn, trie, 0, 0, args) end end set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval) -function has_choice(trace::DynamicDSLTrace, addr) - haskey(trace.trie, addr) && trace.trie[addr].is_choice -end - -function has_call(trace::DynamicDSLTrace, addr) - haskey(trace.trie, addr) && !trace.trie[addr].is_choice -end - -function get_choice(trace::DynamicDSLTrace, addr) - choice = trace.trie[addr] - if !choice.is_choice - throw(KeyError(addr)) - end - ChoiceRecord(choice) -end - -function get_call(trace::DynamicDSLTrace, addr) - call = trace.trie[addr] - if call.is_choice - throw(KeyError(addr)) - end - CallRecord(call) -end - -function 
add_choice!(trace::DynamicDSLTrace, addr, retval, score) - if haskey(trace.trie, addr) - error("Value or subtrace already present at address $addr. - The same address cannot be reused for multiple random choices.") - end - trace.trie[addr] = ChoiceOrCallRecord(retval, score, NaN, true) - trace.score += score - trace.isempty = false -end +has_call(trace::DynamicDSLTrace, addr) = haskey(trace.trie, addr) +get_call(trace::DynamicDSLTrace, addr) = trace.trie[addr] function add_call!(trace::DynamicDSLTrace, addr, subtrace) if haskey(trace.trie, addr) - error("Value or subtrace already present at address $addr. + error("Subtrace already present at address $addr. The same address cannot be reused for multiple random choices.") end score = get_score(subtrace) noise = project(subtrace, EmptySelection()) submap = get_choices(subtrace) - trace.isempty = trace.isempty && isempty(submap) - trace.trie[addr] = ChoiceOrCallRecord(subtrace, score, noise, false) + trace.trie[addr] = CallRecord(subtrace, score, noise) trace.score += score trace.noise += noise end @@ -106,47 +47,28 @@ get_gen_fn(trace::DynamicDSLTrace) = trace.gen_fn ## get_choices ## -function get_choices(trace::DynamicDSLTrace) - if !trace.isempty - DynamicDSLChoiceMap(trace.trie) # see below - else - EmptyChoiceMap() - end -end +get_choices(trace::DynamicDSLTrace) = DynamicDSLChoiceMap(trace.trie) struct DynamicDSLChoiceMap <: ChoiceMap - trie::Trie{Any,ChoiceOrCallRecord} + trie::Trie{Any,CallRecord} end get_address_schema(::Type{DynamicDSLChoiceMap}) = DynamicAddressSchema() get_submap(choices::DynamicDSLChoiceMap, addr::Pair) = _get_submap(choices, addr) - function get_submap(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - if has_leaf_node(trie, addr) - # leaf node, must be a call - call = trie[addr] - if call.is_choice - ValueChoiceMap(call.subtrace_or_retval) - else - get_choices(call.subtrace_or_retval) - end - elseif has_internal_node(trie, addr) - # internal node - subtrie = get_internal_node(trie, addr) - DynamicDSLChoiceMap(subtrie) # see below + if haskey(choices.trie.leaf_nodes, addr) + get_choices(choices.trie[addr].subtrace) + elseif haskey(choices.trie.internal_nodes, addr) + DynamicDSLChoiceMap(choices.trie.internal_nodes[addr]) else EmptyChoiceMap() end end function get_submaps_shallow(choices::DynamicDSLChoiceMap) - calls_iter = ( - (key, call.is_choice ? 
ValueChoiceMap(call.subtrace_or_retval) : get_choices(call.subtrace_or_retval)) - for (key, call) in get_leaf_nodes(choices.trie) - ) - internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) - Iterators.flatten((calls_iter, internal_nodes_iter)) + leafs = ((key, get_choices(record.subtrace)) for (key, record) in get_leaf_nodes(choices.trie)) + internals = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) + Iterators.flatten((leafs, internals)) end ## Base.getindex ## @@ -154,13 +76,7 @@ end function _getindex(trace::DynamicDSLTrace, trie::Trie, addr::Pair) (first, rest) = addr if haskey(trie.leaf_nodes, first) - choice_or_call = trie.leaf_nodes[first] - if choice_or_call.is_choice - error("Unknown address $addr; random choice at $first") - else - subtrace = choice_or_call.subtrace_or_retval - return subtrace[rest] - end + return trie.leaf_nodes[first].subtrace[rest] elseif haskey(trie.internal_nodes, first) return _getindex(trace, trie.internal_nodes[first], rest) else @@ -170,14 +86,7 @@ end function _getindex(trace::DynamicDSLTrace, trie::Trie, addr) if haskey(trie.leaf_nodes, addr) - choice_or_call = trie.leaf_nodes[addr] - if choice_or_call.is_choice - # the value of the random choice - return choice_or_call.subtrace_or_retval - else - # the return value of the generative function call - return get_retval(choice_or_call.subtrace_or_retval) - end + return get_retval(trie.leaf_nodes[addr].subtrace) else error("No random choice or generative function call at address $addr") end diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 7acc16302..94f442acf 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -16,57 +16,6 @@ function GFUpdateState(gen_fn, args, prev_trace, constraints, params) 0., visitor, params, discard) end -function traceat(state::GFUpdateState, dist::Distribution{T}, - args::Tuple, key) where {T} - - local prev_retval::T - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for previous choice at this key - has_previous = has_choice(state.prev_trace, key) - if has_previous - prev_choice = get_choice(state.prev_trace, key) - prev_retval = prev_choice.retval - prev_score = prev_choice.score - end - - # check for constraints at this key - constrained = has_value(state.constraints, key) - !constrained && check_is_empty(state.constraints, key) - - # record the previous value as discarded if it is replaced - if constrained && has_previous - set_value!(state.discard, key, prev_retval) - end - - # get return value - if constrained - retval = get_value(state.constraints, key) - elseif has_previous - retval = prev_retval - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) 
- - # update the weight - if has_previous - state.weight += score - prev_score - elseif constrained - state.weight += score - end - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, args::Tuple, key) where {T,U} @@ -78,7 +27,6 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, visit!(state.visitor, key) # check for constraints at this key - check_no_value(state.constraints, key) constraints = get_submap(state.constraints, key) # get subtrace @@ -119,11 +67,11 @@ function splice(state::GFUpdateState, gen_fn::DynamicDSLFunction, retval end -function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function update_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::EmptySelection) score = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - score += choice_or_call.score + for (key, call) in get_leaf_nodes(prev_trie) + score += call.score end for (key, subtrie) in get_internal_nodes(prev_trie) score += update_delete_recurse(subtrie, EmptySelection()) @@ -131,12 +79,12 @@ function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, score end -function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function update_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::DynamicSelection) score = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) + for (key, call) in get_leaf_nodes(prev_trie) if !(key in visited) - score += choice_or_call.score + score += call.score end end for (key, subtrie) in get_internal_nodes(prev_trie) diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl index d0797426c..6200e5b52 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -5,54 +5,6 @@ import Distributions using SpecialFunctions: loggamma, logbeta, digamma -abstract type Distribution{T} end - -""" - val::T = random(dist::Distribution{T}, args...) - -Sample a random choice from the given distribution with the given arguments. -""" -function random end - -""" - lpdf = logpdf(dist::Distribution{T}, value::T, args...) - -Evaluate the log probability (density) of the value. -""" -function logpdf end - -""" - has::Bool = has_output_grad(dist::Distribution) - -Return true of the gradient if the distribution computes the gradient of the logpdf with respect to the value of the random choice. -""" -function has_output_grad end - -""" - grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...) - -Compute the gradient of the logpdf with respect to the value, and each of the arguments. - -If `has_output_grad` returns false, then the first element of the returned tuple is `nothing`. -Otherwise, the first element of the tuple is the gradient with respect to the value. -If the return value of `has_argument_grads` has a false value for at position `i`, then the `i+1`th element of the returned tuple has value `nothing`. -Otherwise, this element contains the gradient with respect to the `i`th argument. 
-""" -function logpdf_grad end - -function is_discrete end - -# NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl - -get_return_type(::Distribution{T}) where {T} = T - -export Distribution -export random -export logpdf -export logpdf_grad -export has_output_grad -export is_discrete - # built-in distributions include("distributions/distributions.jl") From 5bf4207c6fb7ae23208b7559ba866b9fdb1bb717 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 17 Jun 2020 21:37:06 -0400 Subject: [PATCH 15/45] simplify static ir code --- src/address.jl | 3 +- src/choice_map/choice_map.jl | 2 +- src/distribution.jl | 3 + src/static_ir/backprop.jl | 230 +++++++++++++++++------------------ src/static_ir/dag.jl | 34 +----- src/static_ir/generate.jl | 21 ---- src/static_ir/project.jl | 8 -- src/static_ir/render_ir.jl | 11 +- src/static_ir/simulate.jl | 13 -- src/static_ir/trace.jl | 44 +------ src/static_ir/update.jl | 203 +++++-------------------------- test/runtests.jl | 12 +- test/static_dsl.jl | 36 +++--- 13 files changed, 186 insertions(+), 434 deletions(-) diff --git a/src/address.jl b/src/address.jl index 2d6499a6a..ad33cfe79 100644 --- a/src/address.jl +++ b/src/address.jl @@ -151,6 +151,7 @@ A hierarchical selection whose keys are among its type parameters. struct StaticSelection{T,U} <: HierarchicalSelection subselections::NamedTuple{T,U} end +StaticSelection(::NamedTuple{(), Tuple{}}) = EmptySelection() function Base.isempty(selection::StaticSelection{T,U}) where {T,U} length(R) == 0 && all(isempty(node) for node in selection.subselections) @@ -208,7 +209,7 @@ function StaticSelection(other::HierarchicalSelection) (keys, subselections) = ((), ()) end types = map(typeof, subselections) - StaticSelection{keys,Tuple{types...}}(NamedTuple{keys}(subselections)) + StaticSelection(NamedTuple{keys}(subselections)) end export StaticSelection diff --git a/src/choice_map/choice_map.jl b/src/choice_map/choice_map.jl index 213bc5f80..a1ca2eaef 100644 --- a/src/choice_map/choice_map.jl +++ b/src/choice_map/choice_map.jl @@ -147,7 +147,7 @@ end @inline get_submaps_shallow(choices::ValueChoiceMap) = () @inline Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val @inline Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) -@inline get_address_schema(::Type{<:ValueChoiceMap}) = EmptyAddressSchema() +@inline get_address_schema(::Type{<:ValueChoiceMap}) = AllAddressSchema() """ choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) diff --git a/src/distribution.jl b/src/distribution.jl index d72b21a61..a18803043 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -91,6 +91,9 @@ end end # TODO: do I need an update method to handle empty choicemaps which are not `EmptyChoiceMap`s? @inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, selection::EmptySelection) where {n} = (tr, 0., NoChange()) +# TODO: this next regenerate method is here because StaticSelections can have this sort of empty leaf node; choicemaps +# cannot right now and only have empty ones; we should try to fix this if possible. 
+#@inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, ::StaticSelection{(), Tuple{}}) where {n} = (tr, 0., NoChange()) @inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::EmptySelection) new_tr = DistributionTrace(tr.val, args, tr.dist) weight = get_score(new_tr) - get_score(tr) diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index b352d3ca2..eba97a82a 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -36,15 +36,15 @@ function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode) - if node in selected_choices - push!(fwd_marked, node) - end -end - function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode) - if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) - push!(fwd_marked, node) + if node.generative_function isa Distribution + if node in selected_choices + push!(fwd_marked, node) + end + else + if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) + push!(fwd_marked, node) + end end end @@ -60,20 +60,15 @@ function back_pass!(back_marked, node::JuliaNode) end end -function back_pass!(back_marked, node::RandomChoiceNode) - # the logpdf of every random choice is a SINK - for input_node in node.inputs - push!(back_marked, input_node) - end - # the value of every random choice is in back_marked, since it affects its logpdf - push!(back_marked, node) -end - function back_pass!(back_marked, node::GenerativeFunctionCallNode) # the logpdf of every generative function call is a SINK for input_node in node.inputs push!(back_marked, input_node) end + if node.generative_function isa Distribution + # the value of every random choice is in back_marked, since it affects its logpdf + push!(back_marked, node) + end end function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode) @@ -134,35 +129,35 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) - # for reference by other nodes during back_codegen! - # could performance optimize this away - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - - # every random choice is in back_marked, since it affects it logpdf, but - # also possibly due to other downstream usage of the value - @assert node in back_marked +function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) + if node.generative_function isa Distribution + # for reference by other nodes during back_codegen! 
+ # could performance optimize this away + push!(stmts, :($(node.name) = get_retval(trace.$(get_subtrace_fieldname(node))))) - if node in fwd_marked - # the only way we are fwd_marked is if this choice was selected + # every random choice is in back_marked, since it affects it logpdf, but + # also possibly due to other downstream usage of the value + @assert node in back_marked - # initialize gradient with respect to the value of the random choice to zero - # it will be a runtime error, thrown here, if there is no zero() method - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) - end -end + if node in fwd_marked + # the only way we are fwd_marked is if this choice was selected -function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) - # for reference by other nodes during back_codegen! - # could performance optimize this away - subtrace_fieldname = get_subtrace_fieldname(node) - push!(stmts, :($(node.name) = get_retval(trace.$subtrace_fieldname))) + # initialize gradient with respect to the value of the random choice to zero + # it will be a runtime error, thrown here, if there is no zero() method + push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + end + else + # for reference by other nodes during back_codegen! + # could performance optimize this away + subtrace_fieldname = get_subtrace_fieldname(node) + push!(stmts, :($(node.name) = get_retval(trace.$subtrace_fieldname))) - # NOTE: we will still potentially run choice_gradients recursively on the generative function, - # we just might not use its return value gradient. - if node in fwd_marked && node in back_marked - # we are fwd_marked if an input was fwd_marked, or if we were selected internally - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + # NOTE: we will still potentially run choice_gradients recursively on the generative function, + # we just might not use its return value gradient. 
+ if node in fwd_marked && node in back_marked + # we are fwd_marked if an input was fwd_marked, or if we were selected internally + push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + end end end @@ -217,19 +212,19 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: end function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, - node::RandomChoiceNode, logpdf_grad::Symbol) + node::GenerativeFunctionCallNode, logpdf_grad::Symbol) # only evaluate the gradient of the logpdf if we need to if any(input_node in fwd_marked for input_node in node.inputs) || node in fwd_marked args = map((input_node) -> input_node.name, node.inputs) - push!(stmts, :($logpdf_grad = logpdf_grad($(node.dist), $(node.name), $(args...)))) + push!(stmts, :($logpdf_grad = logpdf_grad($(node.generative_function), $(node.name), $(args...)))) end # increment gradients of input nodes that are in fwd_marked for (i, input_node) in enumerate(node.inputs) if input_node in fwd_marked @assert input_node in back_marked # this ensured its gradient will have been initialized - if !has_argument_grads(node.dist)[i] - error("Distribution $(node.dist) does not have logpdf gradient for argument $i") + if !has_argument_grads(node.generative_function)[i] + error("Distribution $(node.generative_function) does not have logpdf gradient for argument $i") end push!(stmts, :($(gradient_var(input_node)) += $logpdf_grad[$(QuoteNode(i+1))])) end @@ -243,94 +238,91 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke end function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropTraceMode) - logpdf_grad = gensym("logpdf_grad") + node::GenerativeFunctionCallNode, mode::BackpropTraceMode) + if node.generative_function isa Distribution + logpdf_grad = gensym("logpdf_grad") - # backpropagate to the inputs - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) + # backpropagate to the inputs + back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) - # backpropagate to the value (if it was selected) - if node in fwd_marked - if !has_output_grad(node.dist) - error("Distribution $dist does not logpdf gradient for its output value") + # backpropagate to the value (if it was selected) + if node in fwd_marked + if !has_output_grad(node.generative_function) + error("Distribution $(node.generative_function) does not logpdf gradient for its output value") + end + push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) + end + else + # handle case when it is the return node + if node === ir.return_node && node in fwd_marked + @assert node in back_marked + push!(stmts, :($(gradient_var(node)) += retval_grad)) end - push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) - end -end - -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropParamsMode) - logpdf_grad = gensym("logpdf_grad") - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) -end - -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::GenerativeFunctionCallNode, mode::BackpropTraceMode) - - # handle case when it is the return node - if node === ir.return_node && node in fwd_marked - @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) - end - if node in fwd_marked - input_grads = gensym("call_input_grads") - value_trie 
= value_trie_var(node) - gradient_trie = gradient_trie_var(node) - subtrace_fieldname = get_subtrace_fieldname(node) - call_selection = gensym("call_selection") - if node in selected_calls - push!(stmts, :($call_selection = $qn_static_getindex(selection, $(QuoteNode(Val(node.addr)))))) - else - push!(stmts, :($call_selection = EmptySelection())) + if node in fwd_marked + input_grads = gensym("call_input_grads") + value_trie = value_trie_var(node) + gradient_trie = gradient_trie_var(node) + subtrace_fieldname = get_subtrace_fieldname(node) + call_selection = gensym("call_selection") + if node in selected_calls + push!(stmts, :($call_selection = $qn_static_getindex(selection, $(QuoteNode(Val(node.addr)))))) + else + push!(stmts, :($call_selection = EmptySelection())) + end + retval_grad = node in back_marked ? gradient_var(node) : :(nothing) + push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( + trace.$subtrace_fieldname, $call_selection, $retval_grad))) end - retval_grad = node in back_marked ? gradient_var(node) : :(nothing) - push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( - trace.$subtrace_fieldname, $call_selection, $retval_grad))) - end - # increment gradients of input nodes that are in fwd_marked - for (i, input_node) in enumerate(node.inputs) - if input_node in fwd_marked - @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + # increment gradients of input nodes that are in fwd_marked + for (i, input_node) in enumerate(node.inputs) + if input_node in fwd_marked + @assert input_node in back_marked # this ensured its gradient will have been initialized + push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + end end - end - # NOTE: the value_trie and gradient_trie are dealt with later + # NOTE: the value_trie and gradient_trie are dealt with later + end end function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, mode::BackpropParamsMode) - # handle case when it is the return node - if node === ir.return_node && node in fwd_marked - @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) - end + if node.generative_function isa Distribution + logpdf_grad = gensym("logpdf_grad") + back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) + else + # handle case when it is the return node + if node === ir.return_node && node in fwd_marked + @assert node in back_marked + push!(stmts, :($(gradient_var(node)) += retval_grad)) + end - if node in fwd_marked - input_grads = gensym("call_input_grads") - subtrace_fieldname = get_subtrace_fieldname(node) - retval_grad = node in back_marked ? gradient_var(node) : :(nothing) - push!(stmts, :($input_grads = accumulate_param_gradients!(trace.$subtrace_fieldname, $retval_grad))) - end + if node in fwd_marked + input_grads = gensym("call_input_grads") + subtrace_fieldname = get_subtrace_fieldname(node) + retval_grad = node in back_marked ? 
gradient_var(node) : :(nothing) + push!(stmts, :($input_grads = accumulate_param_gradients!(trace.$subtrace_fieldname, $retval_grad))) + end - # increment gradients of input nodes that are in fwd_marked - for (i, input_node) in enumerate(node.inputs) - if input_node in fwd_marked - @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + # increment gradients of input nodes that are in fwd_marked + for (i, input_node) in enumerate(node.inputs) + if input_node in fwd_marked + @assert input_node in back_marked # this ensured its gradient will have been initialized + push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + end end end end -function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, +function generate_value_gradient_trie(selected_choices::Set{GenerativeFunctionCallNode}, selected_calls::Set{GenerativeFunctionCallNode}, value_trie::Symbol, gradient_trie::Symbol) selected_choices_vec = collect(selected_choices) quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) - leaf_value_choicemaps = map((node) -> :(ValueChoiceMap(trace.$(get_value_fieldname(node)))), selected_choices_vec) + leaf_value_choicemaps = map((node) -> :(ValueChoiceMap(get_retval(trace.$(get_subtrace_fieldname(node))))), selected_choices_vec) leaf_gradient_choicemaps = map((node) -> :(ValueChoiceMap($(gradient_var(node)))), selected_choices_vec) selected_calls_vec = collect(selected_calls) @@ -350,18 +342,18 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, end function get_selected_choices(::EmptyAddressSchema, ::StaticIR) - Set{RandomChoiceNode}() + Set{GenerativeFunctionCallNode}() end function get_selected_choices(::AllAddressSchema, ir::StaticIR) - Set{RandomChoiceNodes}(ir.choice_nodes) + Set{GenerativeFunctionCallNode}([node for node in ir.call_nodes if node.generative_function isa Distribution]...) end function get_selected_choices(schema::StaticAddressSchema, ir::StaticIR) selected_choice_addrs = Set(keys(schema)) - selected_choices = Set{RandomChoiceNode}() - for node in ir.choice_nodes - if node.addr in selected_choice_addrs + selected_choices = Set{GenerativeFunctionCallNode}() + for node in ir.call_nodes + if node.generative_function isa Distribution && node.addr in selected_choice_addrs push!(selected_choices, node) end end @@ -373,14 +365,14 @@ function get_selected_calls(::EmptyAddressSchema, ::StaticIR) end function get_selected_calls(::AllAddressSchema, ir::StaticIR) - Set{GenerativeFunctionCallNode}(ir.call_nodes) + Set{GenerativeFunctionCallNode}([node for node in ir.call_nodes if !(node.generative_function isa Distribution)]...) end function get_selected_calls(schema::StaticAddressSchema, ir::StaticIR) selected_call_addrs = Set(keys(schema)) selected_calls = Set{GenerativeFunctionCallNode}() for node in ir.call_nodes - if node.addr in selected_call_addrs + if !(node.generative_function isa Distribution) && node.addr in selected_call_addrs push!(selected_calls, node) end end @@ -452,7 +444,7 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, ir = get_ir(gen_fn_type) # unlike choice_gradients we don't take gradients w.r.t. 
the value of random choices - selected_choices = Set{RandomChoiceNode}() + selected_choices = Set{GenerativeFunctionCallNode}() # we need to guarantee that we visit every generative function call, # because we need to backpropagate to its trainable parameters diff --git a/src/static_ir/dag.jl b/src/static_ir/dag.jl index c82658892..de6acf6e7 100644 --- a/src/static_ir/dag.jl +++ b/src/static_ir/dag.jl @@ -18,14 +18,6 @@ struct JuliaNode <: StaticIRNode typ::Union{Symbol,Expr,QuoteNode} end -struct RandomChoiceNode <: StaticIRNode - dist::Distribution - inputs::Vector{StaticIRNode} - addr::Symbol - name::Symbol - typ::Union{Symbol,Expr,QuoteNode} -end - struct GenerativeFunctionCallNode <: StaticIRNode generative_function::GenerativeFunction inputs::Vector{StaticIRNode} @@ -38,7 +30,6 @@ struct StaticIR nodes::Vector{StaticIRNode} trainable_param_nodes::Vector{TrainableParameterNode} arg_nodes::Vector{ArgumentNode} - choice_nodes::Vector{RandomChoiceNode} call_nodes::Vector{GenerativeFunctionCallNode} julia_nodes::Vector{JuliaNode} return_node::StaticIRNode @@ -50,12 +41,10 @@ mutable struct StaticIRBuilder node_set::Set{StaticIRNode} trainable_param_nodes::Vector{TrainableParameterNode} arg_nodes::Vector{ArgumentNode} - choice_nodes::Vector{RandomChoiceNode} call_nodes::Vector{GenerativeFunctionCallNode} julia_nodes::Vector{JuliaNode} return_node::Union{Nothing,StaticIRNode} vars::Set{Symbol} - addrs_to_choice_nodes::Dict{Symbol,RandomChoiceNode} addrs_to_call_nodes::Dict{Symbol,GenerativeFunctionCallNode} accepts_output_grad::Bool end @@ -65,17 +54,15 @@ function StaticIRBuilder() node_set = Set{StaticIRNode}() trainable_param_nodes = Vector{TrainableParameterNode}() arg_nodes = Vector{ArgumentNode}() - choice_nodes = Vector{RandomChoiceNode}() call_nodes = Vector{GenerativeFunctionCallNode}() julia_nodes = Vector{JuliaNode}() return_node = nothing vars = Set{Symbol}() - addrs_to_choice_nodes = Dict{Symbol,RandomChoiceNode}() addrs_to_call_nodes = Dict{Symbol,GenerativeFunctionCallNode}() accepts_output_grad = false - StaticIRBuilder(nodes, node_set, trainable_param_nodes, arg_nodes, choice_nodes, call_nodes, + StaticIRBuilder(nodes, node_set, trainable_param_nodes, arg_nodes, call_nodes, julia_nodes, - return_node, vars, addrs_to_choice_nodes, addrs_to_call_nodes, + return_node, vars, addrs_to_call_nodes, accepts_output_grad) end @@ -87,7 +74,6 @@ function build_ir(builder::StaticIRBuilder) builder.nodes, builder.trainable_param_nodes, builder.arg_nodes, - builder.choice_nodes, builder.call_nodes, builder.julia_nodes, builder.return_node, @@ -109,7 +95,7 @@ function check_inputs_exist(builder::StaticIRBuilder, input_nodes) end function check_addr_unique(builder::StaticIRBuilder, addr::Symbol) - if haskey(builder.addrs_to_choice_nodes, addr) || haskey(builder.addrs_to_call_nodes, addr) + if haskey(builder.addrs_to_call_nodes, addr) error("Address $addr was not unique") end end @@ -164,20 +150,6 @@ function add_constant_node!(builder::StaticIRBuilder, val, node end -function add_addr_node!(builder::StaticIRBuilder, dist::Distribution; - inputs::Vector=[], addr::Symbol=gensym(), - name::Symbol=gensym()) - check_unique_var(builder, name) - check_addr_unique(builder, addr) - check_inputs_exist(builder, inputs) - typ = QuoteNode(get_return_type(dist)) - node = RandomChoiceNode(dist, inputs, addr, name, typ) - _add_node!(builder, node) - builder.addrs_to_choice_nodes[addr] = node - push!(builder.choice_nodes, node) - node -end - function add_addr_node!(builder::StaticIRBuilder, 
gen_fn::GenerativeFunction; inputs::Vector=[], addr::Symbol=gensym(), name::Symbol=gensym()) diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl index b53eae95d..2beecca54 100644 --- a/src/static_ir/generate.jl +++ b/src/static_ir/generate.jl @@ -21,27 +21,6 @@ function process!(state::StaticIRGenerateState, node::JuliaNode, options) end end -function process!(state::StaticIRGenerateState, node::RandomChoiceNode, options) - schema = state.schema - args = map((input_node) -> input_node.name, node.inputs) - incr = gensym("logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) - if isa(schema, StaticAddressSchema) && (node.addr in keys(schema)) - push!(state.stmts, :($(node.name) = $qn_static_get_value(constraints, Val($addr)))) - push!(state.stmts, :($incr = $qn_logpdf($dist, $(node.name), $(args...)))) - push!(state.stmts, :($weight += $incr)) - else - push!(state.stmts, :($(node.name) = $qn_random($dist, $(args...)))) - push!(state.stmts, :($incr = $qn_logpdf($dist, $(node.name), $(args...)))) - end - push!(state.stmts, :($(get_value_fieldname(node)) = $(node.name))) - push!(state.stmts, :($(get_score_fieldname(node)) = $incr)) - push!(state.stmts, :($num_nonempty_fieldname += 1)) - push!(state.stmts, :($total_score_fieldname += $incr)) -end - function process!(state::StaticIRGenerateState, node::GenerativeFunctionCallNode, options) schema = state.schema args = map((input_node) -> input_node.name, node.inputs) diff --git a/src/static_ir/project.jl b/src/static_ir/project.jl index 2f65ecc2c..ed14a2f29 100644 --- a/src/static_ir/project.jl +++ b/src/static_ir/project.jl @@ -5,14 +5,6 @@ end function process!(state::StaticIRProjectState, node) end -function process!(state::StaticIRProjectState, node::RandomChoiceNode) - schema = state.schema - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) - if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.stmts, :($weight += trace.$(get_score_fieldname(node)))) - end -end - function process!(state::StaticIRProjectState, node::GenerativeFunctionCallNode) schema = state.schema addr = QuoteNode(node.addr) diff --git a/src/static_ir/render_ir.jl b/src/static_ir/render_ir.jl index 22e7b3625..880fec505 100644 --- a/src/static_ir/render_ir.jl +++ b/src/static_ir/render_ir.jl @@ -1,7 +1,12 @@ label(node::ArgumentNode) = String(node.name) label(node::JuliaNode) = String(node.name) -label(node::RandomChoiceNode) = "$(node.dist) $(node.addr) $(node.name)" -label(node::GenerativeFunctionCallNode) = "$(node.addr) $(node.name)" +function label(node::GenerativeFunctionCallNode) + if node.generative_function isa Distribution + "$(node.generative_function) $(node.addr) $(node.name)" + else + "$(node.addr) $(node.name)" + end +end function draw_graph(ir::StaticIR, graphviz, fname) dot = graphviz.Digraph() @@ -14,7 +19,7 @@ function draw_graph(ir::StaticIR, graphviz, fname) shape = "diamond" color = "white" parents = [] - elseif isa(node, RandomChoiceNode) + elseif isa(node, GenerativeFunctionCallNode) && node.generative_function isa Distribution shape = "ellipse" color = "white" parents = node.inputs diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl index b2d5429e2..267183ac1 100644 --- a/src/static_ir/simulate.jl +++ b/src/static_ir/simulate.jl @@ -20,19 +20,6 @@ function process!(state::StaticIRSimulateState, node::JuliaNode, 
options) end end -function process!(state::StaticIRSimulateState, node::RandomChoiceNode, options) - args = map((input_node) -> input_node.name, node.inputs) - incr = gensym("logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - push!(state.stmts, :($(node.name) = $qn_random($dist, $(args...)))) - push!(state.stmts, :($incr = $qn_logpdf($dist, $(node.name), $(args...)))) - push!(state.stmts, :($(get_value_fieldname(node)) = $(node.name))) - push!(state.stmts, :($(get_score_fieldname(node)) = $incr)) - push!(state.stmts, :($num_nonempty_fieldname += 1)) - push!(state.stmts, :($total_score_fieldname += $incr)) -end - function process!(state::StaticIRSimulateState, node::GenerativeFunctionCallNode, options) args = map((input_node) -> input_node.name, node.inputs) args_tuple = Expr(:tuple, args...) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 168ccf50e..358a01d76 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -43,18 +43,10 @@ function get_value_fieldname(node::ArgumentNode) Symbol("$(arg_prefix)_$(node.name)") end -function get_value_fieldname(node::RandomChoiceNode) - Symbol("$(choice_value_prefix)_$(node.addr)") -end - function get_value_fieldname(node::JuliaNode) Symbol("$(julia_prefix)_$(node.name)") end -function get_score_fieldname(node::RandomChoiceNode) - Symbol("$(choice_score_prefix)_$(node.addr)") -end - function get_subtrace_fieldname(node::GenerativeFunctionCallNode) Symbol("$(subtrace_prefix)_$(node.addr)") end @@ -75,12 +67,6 @@ function get_trace_fields(ir::StaticIR, options::StaticIRGenerativeFunctionOptio fieldname = get_value_fieldname(node) push!(fields, TraceField(fieldname, node.typ)) end - for node in ir.choice_nodes - value_fieldname = get_value_fieldname(node) - push!(fields, TraceField(value_fieldname, node.typ)) - score_fieldname = get_score_fieldname(node) - push!(fields, TraceField(score_fieldname, QuoteNode(Float64))) - end for node in ir.call_nodes subtrace_fieldname = get_subtrace_fieldname(node) subtrace_type = QuoteNode(get_trace_type(node.generative_function)) @@ -137,11 +123,6 @@ end function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] - for node in ir.choice_nodes - addr = node.addr - value = :(choices.trace.$(get_value_fieldname(node))) - push!(elements, :(($(QuoteNode(addr)), ValueChoiceMap($value)))) - end for node in ir.call_nodes addr = node.addr subtrace = :(choices.trace.$(get_subtrace_fieldname(node))) @@ -175,19 +156,8 @@ function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) end ) end - - choice_getindex_exprs = Expr[] - for node in ir.choice_nodes - push!(choice_getindex_exprs, - quote - function Gen.static_getindex(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) - return trace.$(get_value_fieldname(node)) - end - end - ) - end - return [get_subtrace_exprs; call_getindex_exprs; choice_getindex_exprs] + return [get_subtrace_exprs; call_getindex_exprs] end function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) @@ -201,21 +171,11 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) :(get_choices(choices.trace.$(get_subtrace_fieldname(node))))))) end - # return a ValueChoiceMap if get_submap is run on an address containing a value - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_submap)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, 
:(ValueChoiceMap(choices.trace.$(get_value_fieldname(node))))))) - end methods end function generate_get_schema(ir::StaticIR, trace_struct_name::Symbol) - choice_addrs = [QuoteNode(node.addr) for node in ir.choice_nodes] - call_addrs = [QuoteNode(node.addr) for node in ir.call_nodes] - addrs = vcat(choice_addrs, call_addrs) + addrs = [QuoteNode(node.addr) for node in ir.call_nodes] Expr(:function, Expr(:call, Expr(:(.), Gen, QuoteNode(:get_schema)), :(::Type{$trace_struct_name})), Expr(:block, diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index c806bba3a..03f6d8c83 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -9,7 +9,6 @@ const calldiff_prefix = gensym("calldiff") calldiff_var(node::GenerativeFunctionCallNode) = Symbol("$(calldiff_prefix)_$(node.addr)") const choice_discard_prefix = gensym("choice_discard") -choice_discard_var(node::RandomChoiceNode) = Symbol("$(choice_discard_prefix)_$(node.addr)") const call_discard_prefix = gensym("call_discard") call_discard_var(node::GenerativeFunctionCallNode) = Symbol("$(call_discard_prefix)_$(node.addr)") @@ -19,21 +18,18 @@ call_discard_var(node::GenerativeFunctionCallNode) = Symbol("$(call_discard_pref ######################## struct ForwardPassState - input_changed::Set{Union{RandomChoiceNode,GenerativeFunctionCallNode}} + input_changed::Set{GenerativeFunctionCallNode} value_changed::Set{StaticIRNode} - constrained_or_selected_choices::Set{RandomChoiceNode} constrained_or_selected_calls::Set{GenerativeFunctionCallNode} discard_calls::Set{GenerativeFunctionCallNode} end function ForwardPassState() - input_changed = Set{Union{RandomChoiceNode,GenerativeFunctionCallNode}}() + input_changed = Set{GenerativeFunctionCallNode}() value_changed = Set{StaticIRNode}() - constrained_or_selected_choices = Set{RandomChoiceNode}() constrained_or_selected_calls = Set{GenerativeFunctionCallNode}() discard_calls = Set{GenerativeFunctionCallNode}() - ForwardPassState(input_changed, value_changed, constrained_or_selected_choices, - constrained_or_selected_calls, discard_calls) + ForwardPassState(input_changed, value_changed, constrained_or_selected_calls, discard_calls) end function forward_pass_argdiff!(state::ForwardPassState, @@ -46,30 +42,19 @@ function forward_pass_argdiff!(state::ForwardPassState, end end -function process_forward!(::AddressSchema, ::ForwardPassState, ::TrainableParameterNode) end +function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, ::ForwardPassState, ::TrainableParameterNode) end -function process_forward!(::AddressSchema, ::ForwardPassState, node::ArgumentNode) end +function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, ::ForwardPassState, node::ArgumentNode) end -function process_forward!(::AddressSchema, state::ForwardPassState, node::JuliaNode) +function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, state::ForwardPassState, node::JuliaNode) if any(input_node in state.value_changed for input_node in node.inputs) push!(state.value_changed, node) end end -function process_forward!(schema::AddressSchema, state::ForwardPassState, - node::RandomChoiceNode) - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) - if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.constrained_or_selected_choices, node) - push!(state.value_changed, node) - end - if any(input_node in state.value_changed for input_node in node.inputs) - push!(state.input_changed, node) - 
end -end - -function process_forward!(schema::AddressSchema, state::ForwardPassState, +function process_forward!(constraint_type::Type{<:Union{<:ChoiceMap, Selection}}, state::ForwardPassState, node::GenerativeFunctionCallNode) + schema = get_address_schema(constraint_type) @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) push!(state.constrained_or_selected_calls, node) @@ -78,8 +63,20 @@ function process_forward!(schema::AddressSchema, state::ForwardPassState, end if any(input_node in state.value_changed for input_node in node.inputs) push!(state.input_changed, node) - push!(state.value_changed, node) # TODO can check whether the node is satically absorbing push!(state.discard_calls, node) + + ## check if we can statically guarantee that this generative function has a `NoChange` diff ## + update_fn = constraint_type <: ChoiceMap ? Gen.update : Gen.regenerate + + trace_type = get_trace_type(node.generative_function) + update_rettype = Core.Compiler.return_type( + update_fn, + Tuple{trace_type, Tuple, Tuple, constraint_type} + ) + guaranteed_returns_nochange = update_rettype <: Tuple && update_rettype != Union{} && update_rettype.parameters[3] == NoChange + if !guaranteed_returns_nochange + push!(state.value_changed, node) + end end end @@ -113,15 +110,6 @@ function process_backward!(fwd::ForwardPassState, back::BackwardPassState, end end -function process_backward!(fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, options) - if node in fwd.input_changed || node in fwd.constrained_or_selected_choices - for input_node in node.inputs - push!(back.marked, input_node) - end - end -end - function process_backward!(fwd::ForwardPassState, back::BackwardPassState, node::GenerativeFunctionCallNode, options) if node in fwd.input_changed || node in fwd.constrained_or_selected_calls @@ -189,118 +177,6 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, end end -function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, ::UpdateMode, - options) - if options.track_diffs - - # track diffs - arg_values, _ = arg_values_and_diffs_from_tracked_diffs(node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - push!(stmts, :($(node.name) = $qn_Diffed($qn_static_get_value(constraints, Val($addr)), $qn_unknown_change))) - push!(stmts, :($(choice_discard_var(node)) = trace.$(get_value_fieldname(node)))) - else - push!(stmts, :($(node.name) = $qn_Diffed(trace.$(get_value_fieldname(node)), NoChange()))) - end - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $qn_strip_diff($(node.name)), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = $qn_Diffed(trace.$(get_value_fieldname(node)), $qn_no_change))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - - else - - # no track diffs - arg_values = map((n) -> n.name, node.inputs) - new_logpdf = 
gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - push!(stmts, :($(node.name) = $qn_static_get_value(constraints, Val($addr)))) - push!(stmts, :($(choice_discard_var(node)) = trace.$(get_value_fieldname(node)))) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - end - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $(node.name), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - end -end - -function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, ::RegenerateMode, - options) - if options.track_diffs - - # track diffs - arg_values, _ = arg_values_and_diffs_from_tracked_diffs(node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - output_value = Expr(:call, qn_strip_diff, node.name) - if node in fwd.constrained_or_selected_choices - # the choice was selected, it does not contribute to the weight - push!(stmts, :($(node.name) = $qn_Diffed($qn_random($dist, $(arg_values...)), UnknownChange()))) - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $output_value, $(arg_values...)))) - else - # the choice was not selected, and the input to the choice changed - # it does contribute to the weight - push!(stmts, :($(node.name) = $qn_Diffed(trace.$(get_value_fieldname(node)), NoChange()))) - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $output_value, $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = $qn_Diffed(trace.$(get_value_fieldname(node)), NoChange()))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - else - - # no track diffs - arg_values = map((n) -> n.name, node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - # the choice was selected, it does not contribute to the weight - push!(stmts, :($(node.name) = $qn_random($dist, $(arg_values...)))) - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $(node.name), $(arg_values...)))) - else - # the choice was not selected, and the input to the choice changed - # it does contribute to the weight - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $(node.name), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, 
:($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - end -end - function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, node::GenerativeFunctionCallNode, ::UpdateMode, options) @@ -431,33 +307,20 @@ function generate_new_trace!(stmts::Vector{Expr}, trace_type::Type, options) end end -function generate_discard!(stmts::Vector{Expr}, - constrained_choices::Set{RandomChoiceNode}, - discard_calls::Set{GenerativeFunctionCallNode}) - discard_leaf_nodes = Dict{Symbol,Symbol}() - for node in constrained_choices - discard_leaf_nodes[node.addr] = choice_discard_var(node) - end - discard_internal_nodes = Dict{Symbol,Symbol}() +function generate_discard!(stmts::Vector{Expr}, discard_calls::Set{GenerativeFunctionCallNode}) + discard_nodes = Dict{Symbol,Symbol}() for node in discard_calls - discard_internal_nodes[node.addr] = call_discard_var(node) - end - if length(discard_leaf_nodes) > 0 - (leaf_keys, leaf_nodes) = collect(zip(discard_leaf_nodes...)) - else - (leaf_keys, leaf_nodes) = ((), ()) + discard_nodes[node.addr] = call_discard_var(node) end - if length(discard_internal_nodes) > 0 - (internal_keys, internal_nodes) = collect(zip(discard_internal_nodes...)) + + if length(discard_nodes) > 0 + (keys, nodes) = collect(zip(discard_nodes...)) else - (internal_keys, internal_nodes) = ((), ()) + (keys, nodes) = ((), ()) end - leaf_keys = map((key::Symbol) -> QuoteNode(key), leaf_keys) - internal_keys = map((key::Symbol) -> QuoteNode(key), internal_keys) - all_keys = (leaf_keys..., internal_keys...) - all_nodes = ([:($(QuoteNode(ValueChoiceMap))($node)) for node in leaf_nodes]..., internal_nodes...) 
+ keys = map((key::Symbol) -> QuoteNode(key), keys) expr = quote $(QuoteNode(StaticChoiceMap))( - $(QuoteNode(NamedTuple)){($(all_keys...),)}(($(all_nodes...),))) end + $(QuoteNode(NamedTuple)){($(keys...),)}(($(nodes...),))) end push!(stmts, :($discard = $expr)) end @@ -482,7 +345,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ fwd_state = ForwardPassState() forward_pass_argdiff!(fwd_state, ir.arg_nodes, argdiffs_type) for node in ir.nodes - process_forward!(schema, fwd_state, node) + process_forward!(constraints_type, fwd_state, node) end # backward marking pass @@ -505,7 +368,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ end generate_return_value!(stmts, fwd_state, ir.return_node, options) generate_new_trace!(stmts, trace_type, options) - generate_discard!(stmts, fwd_state.constrained_or_selected_choices, fwd_state.discard_calls) + generate_discard!(stmts, fwd_state.discard_calls) # return trace and weight and discard and retdiff push!(stmts, :(return ($trace, $weight, $retdiff, $discard))) @@ -530,7 +393,7 @@ function codegen_regenerate(trace_type::Type{T}, args_type::Type, argdiffs_type: fwd_state = ForwardPassState() forward_pass_argdiff!(fwd_state, ir.arg_nodes, argdiffs_type) for node in ir.nodes - process_forward!(schema, fwd_state, node) + process_forward!(selection_type, fwd_state, node) end # backward marking pass diff --git a/test/runtests.jl b/test/runtests.jl index 749236037..5ad3a2d13 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,12 +64,12 @@ end const dx = 1e-6 -include("autodiff.jl") -include("diff.jl") -include("selection.jl") -include("assignment.jl") -include("dynamic_dsl.jl") -include("optional_args.jl") +# include("autodiff.jl") +# include("diff.jl") +# include("selection.jl") +# include("assignment.jl") +# include("dynamic_dsl.jl") +# include("optional_args.jl") include("static_ir/static_ir.jl") include("static_dsl.jl") include("tilde_sugar.jl") diff --git a/test/static_dsl.jl b/test/static_dsl.jl index 317a8427d..5576aa7df 100644 --- a/test/static_dsl.jl +++ b/test/static_dsl.jl @@ -117,14 +117,13 @@ params = ir.arg_nodes[2] @test params.compute_grad # choice nodes and call nodes -@test length(ir.choice_nodes) == 2 -@test length(ir.call_nodes) == 0 +@test length(ir.call_nodes) == 2 # is_outlier -is_outlier = ir.choice_nodes[1] +is_outlier = ir.call_nodes[1] @test is_outlier.addr == :z @test is_outlier.typ == QuoteNode(Bool) -@test is_outlier.dist == bernoulli +@test is_outlier.generative_function == bernoulli @test length(is_outlier.inputs) == 1 # std @@ -138,10 +137,10 @@ in2 = std.inputs[2] @test (in1 === is_outlier && in2 === params) || (in2 === is_outlier && in1 === params) # y -y = ir.choice_nodes[2] +y = ir.call_nodes[2] @test y.addr == :y @test y.typ == QuoteNode(Float64) -@test y.dist == normal +@test y.generative_function == normal @test length(y.inputs) == 2 @test y.inputs[2] === std @@ -174,40 +173,39 @@ xs = ir.arg_nodes[1] @test xs.typ == :(Vector{Float64}) @test !xs.compute_grad -# choice nodes and call nodes -@test length(ir.choice_nodes) == 4 -@test length(ir.call_nodes) == 1 +# call nodes +@test length(ir.call_nodes) == 5 # inlier_std -inlier_std = ir.choice_nodes[1] +inlier_std = ir.call_nodes[1] @test inlier_std.addr == :inlier_std @test inlier_std.typ == QuoteNode(Float64) -@test inlier_std.dist == gamma +@test inlier_std.generative_function == gamma @test length(inlier_std.inputs) == 2 # outlier_std -outlier_std = ir.choice_nodes[2] +outlier_std = 
ir.call_nodes[2] @test outlier_std.addr == :outlier_std @test outlier_std.typ == QuoteNode(Float64) -@test outlier_std.dist == gamma +@test outlier_std.generative_function == gamma @test length(outlier_std.inputs) == 2 # slope -slope = ir.choice_nodes[3] +slope = ir.call_nodes[3] @test slope.addr == :slope @test slope.typ == QuoteNode(Float64) -@test slope.dist == normal +@test slope.generative_function == normal @test length(slope.inputs) == 2 # intercept -intercept = ir.choice_nodes[4] +intercept = ir.call_nodes[4] @test intercept.addr == :intercept @test intercept.typ == QuoteNode(Float64) -@test intercept.dist == normal +@test intercept.generative_function == normal @test length(intercept.inputs) == 2 # data -ys = ir.call_nodes[1] +ys = ir.call_nodes[5] @test ys.addr == :data @test ys.typ == QuoteNode(PersistentVector{Float64}) @test ys.generative_function == data_fn @@ -376,7 +374,7 @@ ir2 = Gen.get_ir(typeof(f2)) return_node1 = ir1.return_node return_node2 = ir2.return_node @test isa(return_node2, typeof(return_node1)) -@test return_node2.dist == return_node1.dist +@test return_node2.generative_function == return_node1.generative_function inputs1 = return_node1.inputs inputs2 = return_node2.inputs From 61673a46cef1af989b003e226e762fc21c1814eb Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 17 Jun 2020 21:59:39 -0400 Subject: [PATCH 16/45] brief documentation for Dist <: GenFn --- docs/src/ref/distributions.md | 6 ++++++ docs/src/ref/extending.md | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/src/ref/distributions.md b/docs/src/ref/distributions.md index c81801d43..f4b6d7688 100644 --- a/docs/src/ref/distributions.md +++ b/docs/src/ref/distributions.md @@ -1,5 +1,11 @@ # Probability Distributions +In Gen, a probability distribution is a generative function which makes a single random choice +and returns the value of this choice. The choicemap for a probability distribution +is always a [`ValueChoiceMap`](@ref). In addition to supporting the regular `GFI` methods, +every distribution supports the methods [`random`](@ref) and [`logpdf`](@ref), described +in the [Distribution API](@ref custom_distributions). + Gen provides a library of built-in probability distributions, and two ways of writing custom distributions, both of which are explained below: diff --git a/docs/src/ref/extending.md b/docs/src/ref/extending.md index 7f9dfd480..b1d759b3a 100644 --- a/docs/src/ref/extending.md +++ b/docs/src/ref/extending.md @@ -110,7 +110,7 @@ Gen's Distribution interface directly, as defined below. Probability distributions are singleton types whose supertype is `Distribution{T}`, where `T` indicates the data type of the random sample. ```julia -abstract type Distribution{T} end +abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace} end ``` A new Distribution type must implement the following methods: @@ -146,6 +146,9 @@ has_output_grad logpdf_grad ``` +Any custom distribution will automatically be a `GenerativeFunction` since `Distribution <: GenerativeFunction`; +implementations of all GFI methods are automatically provided in terms of `random` and `logpdf`. + ## Custom generative functions We recommend the following steps for implementing a new type of generative function, and also looking at the implementation for the [`DynamicDSLFunction`](@ref) type as an example. 
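
As a concrete illustration of the interface documented above, the following is a minimal
sketch (not part of this patch series) of how a user-defined distribution could look under
the new `Distribution <: GenerativeFunction` design. The `Laplace`/`laplace` names and the
sampling scheme are hypothetical; the methods being extended (`random`, `logpdf`,
`logpdf_grad`, `has_output_grad`, `has_argument_grads`, `is_discrete`) are the ones the
Distribution API referenced above asks for.

```julia
using Gen

# A hypothetical Laplace (double-exponential) distribution defined only through the
# Distribution API; the GFI methods (simulate, generate, update, ...) are then provided
# by the generic DistributionTrace implementations.
struct Laplace <: Gen.Distribution{Float64} end
const laplace = Laplace()

# sample via the inverse-CDF transform of a uniform draw on (-1/2, 1/2)
function Gen.random(::Laplace, loc::Real, scale::Real)
    u = rand() - 0.5
    return loc - scale * sign(u) * log1p(-2 * abs(u))
end

# log density: -|x - loc| / scale - log(2 * scale)
Gen.logpdf(::Laplace, x::Real, loc::Real, scale::Real) =
    -abs(x - loc) / scale - log(2 * scale)

Gen.is_discrete(::Laplace) = false
Gen.has_output_grad(::Laplace) = true
Gen.has_argument_grads(::Laplace) = (true, true)

# gradients of the logpdf with respect to (output value, loc, scale)
function Gen.logpdf_grad(::Laplace, x::Real, loc::Real, scale::Real)
    d = x - loc
    return (-sign(d) / scale, sign(d) / scale, abs(d) / scale^2 - 1 / scale)
end

# With the patches above applied, `laplace` is itself a generative function:
#   tr = Gen.simulate(laplace, (0.0, 1.0))
#   Gen.get_choices(tr) isa Gen.ValueChoiceMap   # a single-value choicemap
```

Because the distribution subtypes `GenerativeFunction`, it can be traced at an address or
passed to combinators such as `Map` like any other generative function, which is what the
`Map(bernoulli)` test added in the next patch exercises.
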
From 298a333fc4a7646a8f5ec64c00174453a55fff56 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 17 Jun 2020 21:59:50 -0400 Subject: [PATCH 17/45] short map over distribution test --- test/assignment.jl | 4 ++-- test/modeling_library/map.jl | 15 +++++++++++++++ test/runtests.jl | 12 ++++++------ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/test/assignment.jl b/test/assignment.jl index 69890297f..7485f92c7 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -27,8 +27,8 @@ @test !has_value(vcm1, :addr) @test isapprox(vcm2, ValueChoiceMap(prevfloat(2.))) @test isapprox(vcm1, ValueChoiceMap(prevfloat(2.))) - @test get_address_schema(typeof(vcm1)) == EmptyAddressSchema() - @test get_address_schema(ValueChoiceMap) == EmptyAddressSchema() + @test get_address_schema(typeof(vcm1)) == AllAddressSchema() + @test get_address_schema(ValueChoiceMap) == AllAddressSchema() @test nested_view(vcm1) == 2 end diff --git a/test/modeling_library/map.jl b/test/modeling_library/map.jl index bfb13eb4e..ffe07d77f 100644 --- a/test/modeling_library/map.jl +++ b/test/modeling_library/map.jl @@ -402,4 +402,19 @@ @test isapprox(get_param_grad(foo, :std), expected_std_grad) end + @testset "map over distribution" begin + flip_coins = Map(bernoulli) + coinflips_tr, weight = generate(flip_coins, (fill(0.4, 100),)) + @test weight == 0. + @test coinflips_tr[20] isa Bool + choices = get_choices(coinflips_tr) + @test get_submap(choices, 42) isa ValueChoiceMap{Bool} + val42 = get_value(choices, 42) + new_tr, weight, retdiff, discard = update(coinflips_tr, (fill(0.4, 100),), (NoChange(),), choicemap((42, !val42))) + @test new_tr[42] == !val42 + expected_score_change = logpdf(bernoulli, !val42, 0.4) - logpdf(bernoulli, val42, 0.4) + @test isapprox(get_score(new_tr) - get_score(coinflips_tr), expected_score_change) + @test isapprox(weight, expected_score_change) + end + end diff --git a/test/runtests.jl b/test/runtests.jl index 5ad3a2d13..749236037 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,12 +64,12 @@ end const dx = 1e-6 -# include("autodiff.jl") -# include("diff.jl") -# include("selection.jl") -# include("assignment.jl") -# include("dynamic_dsl.jl") -# include("optional_args.jl") +include("autodiff.jl") +include("diff.jl") +include("selection.jl") +include("assignment.jl") +include("dynamic_dsl.jl") +include("optional_args.jl") include("static_ir/static_ir.jl") include("static_dsl.jl") include("tilde_sugar.jl") From e34875a84a9b31f1f799b40ef582ee28308d4ffe Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 09:26:33 -0400 Subject: [PATCH 18/45] default static_get_submap = EmptyChoiceMap --- src/static_ir/trace.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 358a01d76..4ebe28009 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -20,6 +20,7 @@ function get_schema end abstract type StaticIRTrace <: Trace end @inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_submap(trace::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() @inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false From 972d4555907813ec7fd77a2b202fdfdacf4d5f79 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 09:29:06 -0400 Subject: [PATCH 19/45] default static_get_submap = EmptyChoiceMap --- src/static_ir/trace.jl | 1 + 1 file changed, 1 insertion(+) diff --git 
a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 168ccf50e..a79ed539b 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -20,6 +20,7 @@ function get_schema end abstract type StaticIRTrace <: Trace end @inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() @inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false From ee64d12fad1b25d645d642ec2422b8dc1a62ae6f Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 09:54:13 -0400 Subject: [PATCH 20/45] dist performance improvements --- src/distribution.jl | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/distribution.jl b/src/distribution.jl index a18803043..878fb7ad2 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -5,13 +5,15 @@ struct DistributionTrace{T, Dist} <: Trace val::T args - dist::Dist + score::Float64 end +@inline dist(::DistributionTrace{T, Dist}) where {T, Dist} = Dist() abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end +@inline DistributionTrace(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...)) function Base.convert(::Type{DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} - DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, tr.dist) + DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, dist(tr)) end """ @@ -62,10 +64,8 @@ get_return_type(::Distribution{T}) where {T} = T @inline Gen.get_args(trace::DistributionTrace) = trace.args @inline Gen.get_choices(trace::DistributionTrace) = ValueChoiceMap(trace.val) # should be able to get type of val @inline Gen.get_retval(trace::DistributionTrace) = trace.val -@inline Gen.get_gen_fn(trace::DistributionTrace) = trace.dist - -# TODO: for performance would it be better to store the score in the trace? -@inline Gen.get_score(trace::DistributionTrace) = logpdf(trace.dist, trace.val, trace.args...) +@inline Gen.get_gen_fn(trace::DistributionTrace) = dist(trace) +@inline Gen.get_score(trace::DistributionTrace) = trace.score @inline Gen.project(trace::DistributionTrace, ::EmptySelection) = 0. @inline Gen.project(trace::DistributionTrace, ::AllSelection) = get_score(trace) @@ -80,28 +80,26 @@ end (tr, weight) end @inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::ValueChoiceMap) - new_tr = DistributionTrace(get_value(constraints), args, tr.dist) + new_tr = DistributionTrace(get_value(constraints), args, dist(tr)) weight = get_score(new_tr) - get_score(tr) (new_tr, weight, UnknownChange(), get_choices(tr)) end @inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::EmptyChoiceMap) - new_tr = DistributionTrace(tr.val, args, tr.dist) + new_tr = DistributionTrace(tr.val, args, dist(tr)) weight = get_score(new_tr) - get_score(tr) (new_tr, weight, NoChange(), EmptyChoiceMap()) end +@inline Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, constraints::EmptyChoiceMap) where {n} = (tr, 0., NoChange()) # TODO: do I need an update method to handle empty choicemaps which are not `EmptyChoiceMap`s? 
@inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, selection::EmptySelection) where {n} = (tr, 0., NoChange()) -# TODO: this next regenerate method is here because StaticSelections can have this sort of empty leaf node; choicemaps -# cannot right now and only have empty ones; we should try to fix this if possible. -#@inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, ::StaticSelection{(), Tuple{}}) where {n} = (tr, 0., NoChange()) @inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::EmptySelection) - new_tr = DistributionTrace(tr.val, args, tr.dist) + new_tr = DistributionTrace(tr.val, args, dist(tr)) weight = get_score(new_tr) - get_score(tr) (new_tr, weight, NoChange()) end @inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::AllSelection) - new_val = random(tr.dist, args...) - new_tr = DistributionTrace(new_val, args, tr.dist) + new_val = random(dist(tr), args...) + new_tr = DistributionTrace(new_val, args, dist(tr)) (new_tr, 0., UnknownChange()) end @inline function Gen.propose(dist::Distribution, args::Tuple) @@ -110,7 +108,7 @@ end (ValueChoiceMap(val), score, val) end @inline function Gen.assess(dist::Distribution, args::Tuple, choices::ValueChoiceMap) - weight = logpdf(dist, choices.val, args...) + weight = logpdf(dist, get_value(choices), args...) (weight, choices.val) end From fd1991ff3df029224ddc464642abb0ec15c5ead3 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 10:53:28 -0400 Subject: [PATCH 21/45] minor performance improvement --- src/choice_map/static_choice_map.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index 58ef57d37..1aa40d4f3 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -36,7 +36,7 @@ end function StaticChoiceMap(other::ChoiceMap) keys_and_nodes = collect(get_submaps_shallow(other)) if length(keys_and_nodes) > 0 - (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + (addrs::NTuple{n, Symbol} where {n}, submaps) = zip(keys_and_nodes...) else addrs = () submaps = () From c3d5db029e57d7bcb381a113dca1fa3659983296 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 12:35:10 -0400 Subject: [PATCH 22/45] performance improvement related to zip bug --- src/choice_map/static_choice_map.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index 1aa40d4f3..ff8c01a7e 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -34,9 +34,10 @@ end # convert a nonvalue choicemap all of whose top-level-addresses # are symbols into a staticchoicemap at the top level function StaticChoiceMap(other::ChoiceMap) - keys_and_nodes = collect(get_submaps_shallow(other)) + keys_and_nodes = get_submaps_shallow(other) if length(keys_and_nodes) > 0 - (addrs::NTuple{n, Symbol} where {n}, submaps) = zip(keys_and_nodes...) 
+ addrs = Tuple(key for (key, _) in keys_and_nodes) + submaps = Tuple(submap for (_, submap) in keys_and_nodes) else addrs = () submaps = () From 8a43845bb1e1925125e28ad9131fd5621a9f9d5d Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sat, 20 Jun 2020 10:12:00 -0400 Subject: [PATCH 23/45] better static retdiff checking --- src/static_ir/update.jl | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index 03f6d8c83..e3baf39ab 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -52,29 +52,39 @@ function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, state::ForwardP end end -function process_forward!(constraint_type::Type{<:Union{<:ChoiceMap, Selection}}, state::ForwardPassState, +function cannot_statically_guarantee_nochange_retdiff(constraint_type, node, state) + update_fn = constraint_type <: ChoiceMap ? Gen.update : Gen.regenerate + + trace_type = get_trace_type(node.generative_function) + argdiff_types = map(input_node -> input_node in state.value_changed ? UnknownChange : NoChange, node.inputs) + argdiff_type = Tuple{argdiff_types...} + # TODO: can we know the arg type statically? + update_rettype = Core.Compiler.return_type( + update_fn, + Tuple{trace_type, Tuple, argdiff_type, constraint_type} + ) + has_static_retdiff = update_rettype <: Tuple && update_rettype != Union{} && length(update_rettype.parameters) > 3 + guaranteed_returns_nochange = has_static_retdiff && update_rettype.parameters[3] == NoChange + + return !guaranteed_returns_nochange +end + +function process_forward!(constraint_type::Type{<:Union{<:ChoiceMap, <:Selection}}, state::ForwardPassState, node::GenerativeFunctionCallNode) schema = get_address_schema(constraint_type) + will_run_update = false @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) push!(state.constrained_or_selected_calls, node) - push!(state.value_changed, node) - push!(state.discard_calls, node) + will_run_update = true end if any(input_node in state.value_changed for input_node in node.inputs) push!(state.input_changed, node) + will_run_update = true + end + if will_run_update push!(state.discard_calls, node) - - ## check if we can statically guarantee that this generative function has a `NoChange` diff ## - update_fn = constraint_type <: ChoiceMap ? 
Gen.update : Gen.regenerate - - trace_type = get_trace_type(node.generative_function) - update_rettype = Core.Compiler.return_type( - update_fn, - Tuple{trace_type, Tuple, Tuple, constraint_type} - ) - guaranteed_returns_nochange = update_rettype <: Tuple && update_rettype != Union{} && update_rettype.parameters[3] == NoChange - if !guaranteed_returns_nochange + if cannot_statically_guarantee_nochange_retdiff(constraint_type, node, state) push!(state.value_changed, node) end end From ffd9373c243593fbf46ef9b4331a93b9a787fdfd Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 10:39:43 -0400 Subject: [PATCH 24/45] add static info for dist trace type --- src/distribution.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/distribution.jl b/src/distribution.jl index 878fb7ad2..7dd5aa11b 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -12,6 +12,9 @@ end abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end @inline DistributionTrace(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...)) +# we need to know the specific distribution in the trace type so the compiler can specialize GFI calls fully +@inline get_trace_type(::Dist) where {T, Dist <: Distribution{T}} = DistributionTrace{T, Dist} + function Base.convert(::Type{DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, dist(tr)) end From 67d5e120c07e7a0b32a658e45698cc674b304334 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 15:09:06 -0400 Subject: [PATCH 25/45] don't use static get_submap for staticchoicemap --- src/choice_map/static_choice_map.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index ff8c01a7e..587fc6ee5 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -17,7 +17,15 @@ end @inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) @inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) -@inline get_submap(choices::StaticChoiceMap, addr::Symbol) = static_get_submap(choices, Val(addr)) + +# TODO: would it be faster to do static_get_submap? 
+function get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, addr::Symbol) where {Addrs, SubmapTypes} + if addr in Addrs + choices.submaps[addr] + else + EmptyChoiceMap() + end +end @generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} if A in Addrs From 4966ea9ee0f10f633d854f5c0bef3d17adaf712e Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 15:27:59 -0400 Subject: [PATCH 26/45] some simple MH benchmarks --- test/benchmarks/dynamic_mh.jl | 77 +++++++++++++++++++++++++++ test/benchmarks/run_benchmarks.jl | 2 + test/benchmarks/static_mh.jl | 87 +++++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+) create mode 100644 test/benchmarks/dynamic_mh.jl create mode 100644 test/benchmarks/run_benchmarks.jl create mode 100644 test/benchmarks/static_mh.jl diff --git a/test/benchmarks/dynamic_mh.jl b/test/benchmarks/dynamic_mh.jl new file mode 100644 index 000000000..fb7661a7d --- /dev/null +++ b/test/benchmarks/dynamic_mh.jl @@ -0,0 +1,77 @@ +module DynamicMHBenchmark +using Gen +import Random + +include("../../examples/regression/dynamic_model.jl") +include("../../examples/regression/dataset.jl") + +@gen function slope_proposal(trace) + slope = trace[:slope] + @trace(normal(slope, 0.5), :slope) +end + +@gen function intercept_proposal(trace) + intercept = trace[:intercept] + @trace(normal(intercept, 0.5), :intercept) +end + +@gen function inlier_std_proposal(trace) + log_inlier_std = trace[:log_inlier_std] + @trace(normal(log_inlier_std, 0.5), :log_inlier_std) +end + +@gen function outlier_std_proposal(trace) + log_outlier_std = trace[:log_outlier_std] + @trace(normal(log_outlier_std, 0.5), :log_outlier_std) +end + +@gen function is_outlier_proposal(trace, i::Int) + prev = trace[:data => i => :z] + @trace(bernoulli(prev ? 
0.0 : 1.0), :data => i => :z) +end + +function do_inference(xs, ys, num_iters) + observations = choicemap() + for (i, y) in enumerate(ys) + observations[:data => i => :y] = y + end + + # initial trace + (trace, _) = generate(model, (xs,), observations) + + scores = Vector{Float64}(undef, num_iters) + for i=1:num_iters + + # steps on the parameters + for j=1:5 + (trace, _) = metropolis_hastings(trace, slope_proposal, ()) + (trace, _) = metropolis_hastings(trace, intercept_proposal, ()) + (trace, _) = metropolis_hastings(trace, inlier_std_proposal, ()) + (trace, _) = metropolis_hastings(trace, outlier_std_proposal, ()) + end + + # step on the outliers + for j=1:length(xs) + (trace, _) = metropolis_hastings(trace, is_outlier_proposal, (j,)) + end + + score = get_score(trace) + scores[i] = score + + # print + slope = trace[:slope] + intercept = trace[:intercept] + inlier_std = exp(trace[:log_inlier_std]) + outlier_std = exp(trace[:log_outlier_std]) + end + return scores +end + +println("Simple dynamic DSL MH on regression model:") +(xs, ys) = make_data_set(200) +do_inference(xs, ys, 10) +@time do_inference(xs, ys, 50) +@time do_inference(xs, ys, 50) +println() + +end \ No newline at end of file diff --git a/test/benchmarks/run_benchmarks.jl b/test/benchmarks/run_benchmarks.jl new file mode 100644 index 000000000..13a4754e8 --- /dev/null +++ b/test/benchmarks/run_benchmarks.jl @@ -0,0 +1,2 @@ +include("static_mh.jl") +include("dynamic_mh.jl") \ No newline at end of file diff --git a/test/benchmarks/static_mh.jl b/test/benchmarks/static_mh.jl new file mode 100644 index 000000000..0e801631c --- /dev/null +++ b/test/benchmarks/static_mh.jl @@ -0,0 +1,87 @@ +module StaticMHBenchmark +using Gen +import Random + +include("../../examples/regression/static_model.jl") +include("../../examples/regression/dataset.jl") + +@gen (static) function slope_proposal(trace) + slope = trace[:slope] + @trace(normal(slope, 0.5), :slope) +end + +@gen (static) function intercept_proposal(trace) + intercept = trace[:intercept] + @trace(normal(intercept, 0.5), :intercept) +end + +@gen (static) function inlier_std_proposal(trace) + log_inlier_std = trace[:log_inlier_std] + @trace(normal(log_inlier_std, 0.5), :log_inlier_std) +end + +@gen (static) function outlier_std_proposal(trace) + log_outlier_std = trace[:log_outlier_std] + @trace(normal(log_outlier_std, 0.5), :log_outlier_std) +end + +@gen (static) function flip_z(z::Bool) + @trace(bernoulli(z ? 0.0 : 1.0), :z) +end + +@gen (static) function is_outlier_proposal(trace, i::Int) + prev_z = trace[:data => i => :z] + @trace(bernoulli(prev_z ? 0.0 : 1.0), :data => i => :z) +end + +@gen (static) function is_outlier_proposal(trace, i::Int) + prev_z = trace[:data => i => :z] + @trace(bernoulli(prev_z ? 
0.0 : 1.0), :data => i => :z) +end + +Gen.load_generated_functions() + +function do_inference(xs, ys, num_iters) + observations = choicemap() + for (i, y) in enumerate(ys) + observations[:data => i => :y] = y + end + + # initial trace + (trace, _) = generate(model, (xs,), observations) + + scores = Vector{Float64}(undef, num_iters) + for i=1:num_iters + + # steps on the parameters + for j=1:5 + (trace, _) = metropolis_hastings(trace, slope_proposal, ()) + (trace, _) = metropolis_hastings(trace, intercept_proposal, ()) + (trace, _) = metropolis_hastings(trace, inlier_std_proposal, ()) + (trace, _) = metropolis_hastings(trace, outlier_std_proposal, ()) + end + + # step on the outliers + for j=1:length(xs) + (trace, _) = metropolis_hastings(trace, is_outlier_proposal, (j,)) + end + + score = get_score(trace) + scores[i] = score + + # print + slope = trace[:slope] + intercept = trace[:intercept] + inlier_std = exp(trace[:log_inlier_std]) + outlier_std = exp(trace[:log_outlier_std]) + end + return scores +end + +(xs, ys) = make_data_set(200) +do_inference(xs, ys, 10) +println("Simple static DSL (including CallAt nodes) MH on regression model:") +@time do_inference(xs, ys, 50) +@time do_inference(xs, ys, 50) +println() +end \ No newline at end of file From 0909a5b3eb9e99829dd3bfe99cf3612051b7c543 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 15:41:21 -0400 Subject: [PATCH 27/45] bug fix --- src/distribution.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution.jl b/src/distribution.jl index 7dd5aa11b..ce5455b78 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -16,7 +16,7 @@ abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end @inline get_trace_type(::Dist) where {T, Dist <: Distribution{T}} = DistributionTrace{T, Dist} function Base.convert(::Type{DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} - DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, dist(tr)) + DistributionTrace(convert(U, tr.val), tr.args, dist(tr)) end """ From 47cca5980c8b7be38b6215b4b49a69101c2a8173 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 16:42:28 -0400 Subject: [PATCH 28/45] remove ChoiceAt; bug fixes --- src/distribution.jl | 5 +- src/dsl/static.jl | 10 +- src/modeling_library/call_at/call_at.jl | 50 +++++- src/modeling_library/choice_at/choice_at.jl | 177 -------------------- src/modeling_library/modeling_library.jl | 1 - src/static_ir/update.jl | 4 +- test/modeling_library/call_at.jl | 2 +- test/modeling_library/choice_at.jl | 6 +- test/static_dsl.jl | 12 +- 9 files changed, 60 insertions(+), 207 deletions(-) delete mode 100644 src/modeling_library/choice_at/choice_at.jl diff --git a/src/distribution.jl b/src/distribution.jl index ce5455b78..6354558d7 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -10,13 +10,14 @@ end @inline dist(::DistributionTrace{T, Dist}) where {T, Dist} = Dist() abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end +DistributionTrace{T, Dist}(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...)) @inline DistributionTrace(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...)) # we need to know the specific distribution in the trace type so the compiler can specialize GFI calls fully @inline get_trace_type(::Dist) where {T, Dist <: 
Distribution{T}} = DistributionTrace{T, Dist} -function Base.convert(::Type{DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} - DistributionTrace(convert(U, tr.val), tr.args, dist(tr)) +function Base.convert(::Type{<:DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} + DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, tr.score) end """ diff --git a/src/dsl/static.jl b/src/dsl/static.jl index 8cafd9f94..c9ed6958a 100644 --- a/src/dsl/static.jl +++ b/src/dsl/static.jl @@ -51,10 +51,6 @@ end split_addr!(keys, addr_expr::QuoteNode) = push!(keys, addr_expr) split_addr!(keys, addr_expr::Symbol) = push!(keys, addr_expr) -"Construct choice-at or call-at combinator depending on type." -choice_or_call_at(gen_fn::GenerativeFunction, addr_typ) = call_at(gen_fn, addr_typ) -choice_or_call_at(dist::Distribution, addr_typ) = choice_at(dist, addr_typ) - "Generate informative node name for a Julia expression." gen_node_name(arg::Any) = gensym(string(arg)) gen_node_name(arg::Expr) = gensym(arg.head) @@ -78,12 +74,12 @@ function parse_trace_expr!(stmts, bindings, fn, args, addr) end addr = keys[1].value # Get top level address if length(keys) > 1 - # For each nesting level, wrap gen_fn_or_dist within choice_at / call_at + # For each nesting level, wrap gen_fn_or_dist within call_at for key in keys[2:end] push!(stmts, :($(esc(gen_fn_or_dist)) = - choice_or_call_at($(esc(gen_fn_or_dist)), Any))) + call_at($(esc(gen_fn_or_dist)), Any))) end - # Append the nested addresses as arguments to choice_at / call_at + # Append the nested addresses as arguments to call_at args = [args; reverse(keys[2:end])] end # Handle arguments to the traced call diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index f17d061f8..f997ba394 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -140,18 +140,50 @@ function regenerate(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, end function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad) - subselection = selection[trace.key] - (kernel_input_grads, value_submap, gradient_submap) = choice_gradients( - trace.subtrace, subselection, retval_grad) - input_grads = (kernel_input_grads..., nothing) - value_choices = CallAtChoiceMap(trace.key, value_submap) - gradient_choices = CallAtChoiceMap(trace.key, gradient_submap) - (input_grads, value_choices, gradient_choices) + if trace.subtrace isa DistributionTrace + if retval_grad !== nothing && !has_output_grad(get_gen_fn(trace.subtrace)) + error("return value gradient not accepted but one was provided") + end + kernel_arg_grads = logpdf_grad(get_gen_fn(trace.subtrace), get_retval(trace.subtrace), get_args(trace.subtrace)...) 
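        # `logpdf_grad(dist, x, args...)` returns a tuple `(∂logpdf/∂x, ∂logpdf/∂arg1, ...)`,
        # with `nothing` for any component whose gradient is unavailable; element 1 is the
        # gradient for the traced value itself and the tail covers the kernel arguments.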
+ if trace.key in selection + value_choices = CallAtChoiceMap(trace.key, get_choices(trace.subtrace)) + choice_grad = kernel_arg_grads[1] + if choice_grad === nothing + error("gradient not available for selected choice") + end + if retval_grad !== nothing + choice_grad += retval_grad + end + gradient_choices = CallAtChoiceMap(trace.key, ValueChoiceMap(choice_grad)) + else + value_choices = EmptyChoiceMap() + gradient_choices = EmptyChoiceMap() + end + input_grads = (kernel_arg_grads[2:end]..., nothing) + return (input_grads, value_choices, gradient_choices) + else + subselection = selection[trace.key] + (kernel_input_grads, value_submap, gradient_submap) = choice_gradients( + trace.subtrace, subselection, retval_grad) + input_grads = (kernel_input_grads..., nothing) + value_choices = CallAtChoiceMap(trace.key, value_submap) + gradient_choices = CallAtChoiceMap(trace.key, gradient_submap) + return (input_grads, value_choices, gradient_choices) + end end function accumulate_param_gradients!(trace::CallAtTrace, retval_grad) - kernel_input_grads = accumulate_param_gradients!(trace.subtrace, retval_grad) - (kernel_input_grads..., nothing) + if trace.subtrace isa DistributionTrace + if retval_grad !== nothing && !has_output_grad(trace.gen_fn.dist) + error("return value gradient not accepted but one was provided") + end + kernel_arg_grads = logpdf_grad(get_gen_fn(trace.subtrace), get_retval(trace.subtrace), get_args(trace.subtrace)...) + return (kernel_arg_grads[2:end]..., nothing) + else + kernel_input_grads = accumulate_param_gradients!(trace.subtrace, retval_grad) + return (kernel_input_grads..., nothing) + end + end export call_at diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl deleted file mode 100644 index f38758956..000000000 --- a/src/modeling_library/choice_at/choice_at.jl +++ /dev/null @@ -1,177 +0,0 @@ -# TODO optimize ChoiceAtTrace using type parameters - -struct ChoiceAtTrace <: Trace - gen_fn::GenerativeFunction # the ChoiceAtCombinator (not the kernel) - value::Any - key::Any - kernel_args::Tuple - score::Float64 -end - -get_args(trace::ChoiceAtTrace) = (trace.kernel_args..., trace.key) -get_retval(trace::ChoiceAtTrace) = trace.value -get_score(trace::ChoiceAtTrace) = trace.score -get_gen_fn(trace::ChoiceAtTrace) = trace.gen_fn - -struct ChoiceAtChoiceMap{T,K} <: ChoiceMap - key::K - value::T -end - -get_choices(trace::ChoiceAtTrace) = ChoiceAtChoiceMap(trace.key, trace.value) -Base.isempty(::ChoiceAtChoiceMap) = false -function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap} - SingleDynamicKeyAddressSchema() -end -get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr) -has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr) -get_submap(choices::ChoiceAtChoiceMap, addr::Pair) = _get_submap(choices, addr) -function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K} - choices.key == addr ? choices.value : throw(KeyError(choices, addr)) -end -get_submap(choices::ChoiceAtChoiceMap, addr) = addr == choices.key ? 
ValueChoiceMap(choices.value) : EmptyChoiceMap() -get_submaps_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, ValueChoiceMap(choices.value)),) -get_values_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, choices.value),) - -struct ChoiceAtCombinator{T,K} <: GenerativeFunction{T, ChoiceAtTrace} - dist::Distribution{T} -end - -accepts_output_grad(gen_fn::ChoiceAtCombinator) = has_output_grad(gen_fn.dist) - -# TODO -# accepts_output_grad is true if the return value is dependent on the 'gradient source elements' -# if the random choice itself is not a 'gradient source element' then it is independent (false) -# if the random choice is a 'gradient source element', then the return value is dependent (true) -# we will consider the random choice as a gradient source element if the -# distribution has has_output_grad = true) - -function choice_at(dist::Distribution{T}, ::Type{K}) where {T,K} - ChoiceAtCombinator{T,K}(dist) -end - -unpack_choice_at_args(args) = (args[end], args[1:end-1]) - -function assess(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - value = get_value(choices, key) - weight = logpdf(gen_fn.dist, value, kernel_args...) - (weight, value) -end - -function propose(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - value = random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - choices = ChoiceAtChoiceMap(key, value) - (choices, score, value) -end - -function simulate(gen_fn::ChoiceAtCombinator, args::Tuple) - (key, kernel_args) = unpack_choice_at_args(args) - value = random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - ChoiceAtTrace(gen_fn, value, key, kernel_args, score) -end - -function generate(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - constrained = has_value(choices, key) - value = constrained ? get_value(choices, key) : random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - trace = ChoiceAtTrace(gen_fn, value, key, kernel_args, score) - weight = constrained ? score : 0. - (trace, weight) -end - -function project(trace::ChoiceAtTrace, selection::Selection) - (trace.key in selection) ? trace.score : 0. -end - -function update(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, - choices::ChoiceMap) - (key, kernel_args) = unpack_choice_at_args(args) - key_changed = (key != trace.key) - constrained = has_value(choices, key) - if key_changed && constrained - new_value = get_value(choices, key) - discard = ChoiceAtChoiceMap(trace.key, trace.value) - elseif !key_changed && constrained - new_value = get_value(choices, key) - discard = ChoiceAtChoiceMap(key, trace.value) - elseif !key_changed && !constrained - new_value = trace.value - discard = EmptyChoiceMap() - else - error("New address $key not constrained in update") - end - new_score = logpdf(trace.gen_fn.dist, new_value, kernel_args...) 
- new_trace = ChoiceAtTrace(trace.gen_fn, new_value, key, kernel_args, new_score) - weight = new_score - trace.score - (new_trace, weight, UnknownChange(), discard) -end - -function regenerate(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, - selection::Selection) - (key, kernel_args) = unpack_choice_at_args(args) - key_changed = (key != trace.key) - selected = key in selection - if !key_changed && selected - new_value = random(trace.gen_fn.dist, kernel_args...) - elseif !key_changed && !selected - new_value = trace.value - elseif key_changed && !selected - new_value = random(trace.gen_fn.dist, kernel_args...) - else - error("Cannot select new address $key in regenerate") - end - new_score = logpdf(trace.gen_fn.dist, new_value, kernel_args...) - if !key_changed && selected - weight = 0. - elseif !key_changed && !selected - weight = new_score - trace.score - elseif key_changed && !selected - weight = 0. - end - new_trace = ChoiceAtTrace(trace.gen_fn, new_value, key, kernel_args, new_score) - (new_trace, weight, UnknownChange()) -end - -function choice_gradients(trace::ChoiceAtTrace, selection::Selection, retval_grad) - if retval_grad != nothing && !has_output_grad(trace.gen_fn.dist) - error("return value gradient not accepted but one was provided") - end - kernel_arg_grads = logpdf_grad(trace.gen_fn.dist, trace.value, trace.kernel_args...) - if trace.key in selection - value_choices = ChoiceAtChoiceMap(trace.key, trace.value) - choice_grad = kernel_arg_grads[1] - if choice_grad == nothing - error("gradient not available for selected choice") - end - if retval_grad != nothing - choice_grad += retval_grad - end - gradient_choices = ChoiceAtChoiceMap(trace.key, choice_grad) - else - value_choices = EmptyChoiceMap() - gradient_choices = EmptyChoiceMap() - end - input_grads = (kernel_arg_grads[2:end]..., nothing) - (input_grads, value_choices, gradient_choices) -end - -function accumulate_param_gradients!(trace::ChoiceAtTrace, retval_grad) - if retval_grad != nothing && !has_output_grad(trace.gen_fn.dist) - error("return value gradient not accepted but one was provided") - end - kernel_arg_grads = logpdf_grad(trace.gen_fn.dist, trace.value, trace.kernel_args...) 
- (kernel_arg_grads[2:end]..., nothing) -end - -export choice_at diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl index 6200e5b52..5d9e287bb 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -19,7 +19,6 @@ include("dist_dsl/dist_dsl.jl") include("vector.jl") # built-in generative function combinators -include("choice_at/choice_at.jl") include("call_at/call_at.jl") include("map/map.jl") include("unfold/unfold.jl") diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index e3baf39ab..ea4b17f93 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -63,9 +63,11 @@ function cannot_statically_guarantee_nochange_retdiff(constraint_type, node, sta update_fn, Tuple{trace_type, Tuple, argdiff_type, constraint_type} ) - has_static_retdiff = update_rettype <: Tuple && update_rettype != Union{} && length(update_rettype.parameters) > 3 + has_static_retdiff = update_rettype <: Tuple && update_rettype != Union{} && length(update_rettype.parameters) >= 3 guaranteed_returns_nochange = has_static_retdiff && update_rettype.parameters[3] == NoChange + # println("$trace_type, Tuple, $argdiff_type, $constraint_type >> $update_rettype : $has_static_retdiff") + return !guaranteed_returns_nochange end diff --git a/test/modeling_library/call_at.jl b/test/modeling_library/call_at.jl index 607eb61fd..1985c610a 100644 --- a/test/modeling_library/call_at.jl +++ b/test/modeling_library/call_at.jl @@ -1,4 +1,4 @@ -@testset "call_at combinator" begin +@testset "call_at combinator on non-distribution" begin @gen (grad) function foo((grad)(x::Float64)) return x + @trace(normal(x, 1), :y) diff --git a/test/modeling_library/choice_at.jl b/test/modeling_library/choice_at.jl index 4f5241381..69eb52498 100644 --- a/test/modeling_library/choice_at.jl +++ b/test/modeling_library/choice_at.jl @@ -1,6 +1,6 @@ -@testset "choice_at combinator" begin +@testset "call_at combinator on distribution" begin - at = choice_at(bernoulli, Int) + at = call_at(bernoulli, Int) @testset "assess" begin choices = choicemap() @@ -143,7 +143,7 @@ y = 1.2 constraints = choicemap() set_value!(constraints, 3, y) - (trace, _) = generate(choice_at(normal, Int), (0.0, 1.0, 3), constraints) + (trace, _) = generate(call_at(normal, Int), (0.0, 1.0, 3), constraints) # not selected (input_grads, choices, gradients) = choice_gradients( diff --git a/test/static_dsl.jl b/test/static_dsl.jl index 5576aa7df..5d311df6f 100644 --- a/test/static_dsl.jl +++ b/test/static_dsl.jl @@ -40,13 +40,13 @@ end ret = @trace(bernoulli(0.5), :x => i) end -# @trace(choice_at(bernoulli)(0.5, i), :x) +# @trace(call_at(bernoulli)(0.5, i), :x) @gen (static) function at_choice_example_2(i::Int) ret = @trace(bernoulli(0.5), :x => i => :y) end -# @trace(call_at(choice_at(bernoulli))(0.5, i, :y), :x) +# @trace(call_at(call_at(bernoulli))(0.5, i, :y), :x) @gen function foo(mu) @trace(normal(mu, 1), :y) @@ -255,8 +255,8 @@ ret = get_node_by_addr(ir, :x) @test isa(ret.inputs[1], Gen.JuliaNode) # () -> 0.5 @test ret.inputs[2] === i at = ret.generative_function -@test isa(at, Gen.ChoiceAtCombinator) -@test at.dist == bernoulli +@test isa(at, Gen.CallAtCombinator) +@test at.kernel == bernoulli # at_choice_example_2 ir = Gen.get_ir(typeof(at_choice_example_2)) @@ -271,8 +271,8 @@ ret = get_node_by_addr(ir, :x) at = ret.generative_function @test isa(at, Gen.CallAtCombinator) at2 = at.kernel -@test isa(at2, Gen.ChoiceAtCombinator) -@test at2.dist == bernoulli +@test isa(at2, 
Gen.CallAtCombinator) +@test at2.kernel == bernoulli end From 10df9520aa91cd7dd4846058628ecf82d4adb168 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 17:06:22 -0400 Subject: [PATCH 29/45] decrease iters on benchmark --- test/benchmarks/dynamic_mh.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/benchmarks/dynamic_mh.jl b/test/benchmarks/dynamic_mh.jl index fb7661a7d..88392cc06 100644 --- a/test/benchmarks/dynamic_mh.jl +++ b/test/benchmarks/dynamic_mh.jl @@ -70,8 +70,8 @@ end println("Simple dynamic DSL MH on regression model:") (xs, ys) = make_data_set(200) do_inference(xs, ys, 10) -@time do_inference(xs, ys, 50) -@time do_inference(xs, ys, 50) +@time do_inference(xs, ys, 20) +@time do_inference(xs, ys, 20) println() end \ No newline at end of file From a79390e1fdd255283637a81ee9826c813adcc538 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Fri, 3 Jul 2020 12:46:42 -0400 Subject: [PATCH 30/45] merge in updated master --- README.md | 2 +- src/dsl/dsl.jl | 83 ++++++++++++++++++++++++++++++----- src/dsl/static.jl | 12 +++-- src/static_ir/backprop.jl | 14 +++--- src/static_ir/generate.jl | 18 ++++---- src/static_ir/project.jl | 12 ++--- src/static_ir/simulate.jl | 10 ++--- src/static_ir/static_ir.jl | 41 ++++------------- src/static_ir/trace.jl | 28 ++++++------ src/static_ir/update.jl | 78 ++++++++++++++++---------------- test/benchmarks/dynamic_mh.jl | 2 +- test/benchmarks/static_mh.jl | 2 +- test/static_dsl.jl | 21 +++++++++ test/tilde_sugar.jl | 46 ++++++++++++------- 14 files changed, 222 insertions(+), 147 deletions(-) diff --git a/README.md b/README.md index 95a8cbfc8..c2989f855 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Gen.jl -[![Build Status](https://travis-ci.org/probcomp/Gen.svg?branch=master)](https://travis-ci.org/probcomp/Gen.jl) +[![Build Status](https://travis-ci.org/probcomp/Gen.jl.svg?branch=master)](https://travis-ci.org/probcomp/Gen.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://probcomp.github.io/Gen.jl/stable) [![](https://img.shields.io/badge/docs-dev-blue.svg)](https://probcomp.github.io/Gen.jl/dev) diff --git a/src/dsl/dsl.jl b/src/dsl/dsl.jl index f3f5bdab5..a4f3c6b2f 100644 --- a/src/dsl/dsl.jl +++ b/src/dsl/dsl.jl @@ -5,6 +5,7 @@ const DSL_ARG_GRAD_ANNOTATION = :grad const DSL_RET_GRAD_ANNOTATION = :grad const DSL_TRACK_DIFFS_ANNOTATION = :diffs const DSL_NO_JULIA_CACHE_ANNOTATION = :nojuliacache +const DSL_MACROS = Set([Symbol("@trace"), Symbol("@param")]) struct Argument name::Symbol @@ -71,21 +72,81 @@ function address_from_expression(lhs) end function desugar_tildes(expr) + trace_ref = GlobalRef(@__MODULE__, Symbol("@trace")) + line_num = LineNumberNode(1, :none) MacroTools.postwalk(expr) do e + # Replace with globally referenced macrocalls if MacroTools.@capture(e, {*} ~ rhs_) - :(@trace($rhs)) + Expr(:macrocall, trace_ref, line_num, rhs) elseif MacroTools.@capture(e, {addr_} ~ rhs_) - :(@trace($rhs, $(addr))) + Expr(:macrocall, trace_ref, line_num, rhs, addr) elseif MacroTools.@capture(e, lhs_ ~ rhs_) - addr_expr = address_from_expression(lhs) - :($lhs = @trace($rhs, $(addr_expr))) + addr = address_from_expression(lhs) + Expr(:(=), lhs, Expr(:macrocall, trace_ref, line_num, rhs, addr)) else e end end end -function parse_gen_function(ast, annotations) +function resolve_gen_macros(expr, __module__) + MacroTools.postwalk(expr) do e + # Resolve Gen macros to globally referenced macrocalls + if (MacroTools.@capture(e, @namespace_.m_(args__)) && + m in DSL_MACROS && 
__module__.eval(namespace) == @__MODULE__) + macro_ref = GlobalRef(@__MODULE__, m) + line_num = e.args[2] + Expr(:macrocall, macro_ref, line_num, args...) + elseif (MacroTools.@capture(e, @m_(args__)) && + m in DSL_MACROS && isdefined(__module__, m) && + getfield(__module__, m) == getfield(@__MODULE__, m)) + macro_ref = GlobalRef(@__MODULE__, m) + line_num = e.args[2] + Expr(:macrocall, macro_ref, line_num, args...) + else + e + end + end +end + +function extract_quoted_exprs(expr) + quoted_exprs = [] + expr = MacroTools.prewalk(expr) do e + if MacroTools.@capture(e, :(quoted_)) && !isa(quoted, Symbol) + push!(quoted_exprs, e) + Expr(:placeholder, length(quoted_exprs)) + else + e + end + end + return expr, quoted_exprs +end + +function insert_quoted_exprs(expr, quoted_exprs) + expr = MacroTools.prewalk(expr) do e + if MacroTools.@capture(e, p_placeholder) + idx = p.args[1] + quoted_exprs[idx] + else + e + end + end + return expr +end + +function preprocess_body(expr, __module__) + # Protect quoted expressions from pre-processing by extracting them + expr, quoted_exprs = extract_quoted_exprs(expr) + # Desugar tilde calls to globally referenced @trace calls + expr = desugar_tildes(expr) + # Also resolve Gen macros to GlobalRefs for consistent downstream parsing + expr = resolve_gen_macros(expr, __module__) + # Reinsert quoted expressions after pre-processing + expr = insert_quoted_exprs(expr, quoted_exprs) + return expr +end + +function parse_gen_function(ast, annotations, __module__) ast = MacroTools.longdef(ast) if ast.head != :function error("syntax error at $ast in $(ast.head)") @@ -94,7 +155,6 @@ function parse_gen_function(ast, annotations) error("syntax error at $ast in $(ast.args)") end signature = ast.args[1] - body = desugar_tildes(ast.args[2]) if signature.head == :(::) (call_signature, return_type) = signature.args elseif signature.head == :call @@ -102,6 +162,7 @@ function parse_gen_function(ast, annotations) else error("syntax error at $(signature)") end + body = preprocess_body(ast.args[2], __module__) name = call_signature.args[1] args = map(parse_arg, call_signature.args[2:end]) static = DSL_STATIC_ANNOTATION in annotations @@ -112,15 +173,13 @@ function parse_gen_function(ast, annotations) end end -macro gen(annotations_expr, ast) - +macro gen(annotations_expr, ast::Expr) # parse the annotations annotations = parse_annotations(annotations_expr) - # parse the function definition - parse_gen_function(ast, annotations) + parse_gen_function(ast, annotations, __module__) end -macro gen(ast) - parse_gen_function(ast, Set{Symbol}()) +macro gen(ast::Expr) + parse_gen_function(ast, Set{Symbol}(), __module__) end diff --git a/src/dsl/static.jl b/src/dsl/static.jl index c9ed6958a..812aad78a 100644 --- a/src/dsl/static.jl +++ b/src/dsl/static.jl @@ -202,10 +202,12 @@ end "Parse and rewrite expression if it matches an @trace call." 
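# Sketch of the expression shape this parser now receives (illustrative; exact argument
# order follows Base's macrocall Expr layout): after `preprocess_body`, a tilde line such as
#     x ~ normal(0, 1)
# arrives roughly as
#     Expr(:(=), :x, Expr(:macrocall, GlobalRef(Gen, Symbol("@trace")), line, :(normal(0, 1)), :(:x)))
# which is why the matchers below check `isa(m, GlobalRef)` together with `m.name` and
# `m.mod`, rather than comparing against the bare `@trace` symbol.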
function parse_and_rewrite_trace!(stmts, bindings, expr) - if MacroTools.@capture(expr, @m_(f_(xs__), addr_)) && m == STATIC_DSL_TRACE + if (MacroTools.@capture(expr, @m_(f_(xs__), addr_)) && isa(m, GlobalRef) && + m.name == STATIC_DSL_TRACE && m.mod == @__MODULE__) # Parse "@trace(f(xs...), addr)" and return fresh variable parse_trace_expr!(stmts, bindings, f, xs, addr) - elseif MacroTools.@capture(expr, @m_(f_(xs__))) && m == STATIC_DSL_TRACE + elseif (MacroTools.@capture(expr, @m_(f_(xs__))) && isa(m, GlobalRef) && + m.name == STATIC_DSL_TRACE && m.mod == @__MODULE__) # Throw error for @trace expression without address static_dsl_syntax_error(expr, "Address required.") else @@ -219,12 +221,14 @@ function parse_static_dsl_line!(stmts, bindings, line) rewritten = MacroTools.postwalk( e -> parse_and_rewrite_trace!(stmts, bindings, e), line) # If line is a top-level @trace call, we are done - if MacroTools.@capture(line, @m_(f_(x__), a_)) && m == STATIC_DSL_TRACE + if (MacroTools.@capture(line, @m_(f_(x__), a_)) && isa(m, GlobalRef) && + m.name == STATIC_DSL_TRACE && m.mod == @__MODULE__) return end # Match and parse any other top-level expressions line = rewritten - if MacroTools.@capture(line, @m_ expr_) && m == STATIC_DSL_PARAM + if (MacroTools.@capture(line, @m_(expr_)) && isa(m, GlobalRef) && + m.name == STATIC_DSL_PARAM && m.mod == @__MODULE__) # Parse "@param var::T" parse_param_line!(stmts, bindings, expr) elseif MacroTools.@capture(line, lhs_ = rhs_) diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index eba97a82a..38bf2ccad 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -119,7 +119,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) else # regular forward execution. - + # we need the value for initializing gradient to zero (to get the type # and e.g. shape), and for reference by other nodes during # back_codegen! we could be more selective about which JuliaNodes need @@ -266,7 +266,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, subtrace_fieldname = get_subtrace_fieldname(node) call_selection = gensym("call_selection") if node in selected_calls - push!(stmts, :($call_selection = $qn_static_getindex(selection, $(QuoteNode(Val(node.addr)))))) + push!(stmts, :($call_selection = $(GlobalRef(Gen, :static_getindex))(selection, $(QuoteNode(Val(node.addr)))))) else push!(stmts, :($call_selection = EmptySelection())) end @@ -425,7 +425,7 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # assemble value_trie and gradient_trie value_trie = gensym("value_trie") gradient_trie = gensym("gradient_trie") - push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls, + push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls, value_trie, gradient_trie)) # gradients with respect to inputs @@ -434,7 +434,7 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # return values push!(stmts, :(return ($input_grads, $value_trie, $gradient_trie))) - + Expr(:block, stmts...) end @@ -486,20 +486,20 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, # return values push!(stmts, :(return $input_grads)) - + Expr(:block, stmts...) 
end push!(generated_functions, quote -@generated function $(Expr(:(.), Gen, QuoteNode(:choice_gradients)))(trace::T, selection::$(QuoteNode(Selection)), +@generated function $(GlobalRef(Gen, :choice_gradients))(trace::T, selection::$(QuoteNode(Selection)), retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} $(QuoteNode(codegen_choice_gradients))(trace, selection, retval_grad) end end) push!(generated_functions, quote -@generated function $(Expr(:(.), Gen, QuoteNode(:accumulate_param_gradients!)))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} +@generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad) end end) diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl index 2beecca54..a7e9d42ad 100644 --- a/src/static_ir/generate.jl +++ b/src/static_ir/generate.jl @@ -32,16 +32,16 @@ function process!(state::StaticIRGenerateState, node::GenerativeFunctionCallNode incr = gensym("weight") subconstraints = gensym("subconstraints") if isa(schema, StaticAddressSchema) && (node.addr in keys(schema)) - push!(state.stmts, :($subconstraints = $qn_static_get_submap(constraints, Val($addr)))) - push!(state.stmts, :(($subtrace, $incr) = $qn_generate($gen_fn, $args_tuple, $subconstraints))) + push!(state.stmts, :($subconstraints = $(GlobalRef(Gen, :static_get_submap))(constraints, Val($addr)))) + push!(state.stmts, :(($subtrace, $incr) = $(GlobalRef(Gen, :generate))($gen_fn, $args_tuple, $subconstraints))) else - push!(state.stmts, :(($subtrace, $incr) = $qn_generate($gen_fn, $args_tuple, $qn_empty_choice_map))) + push!(state.stmts, :(($subtrace, $incr) = $(GlobalRef(Gen, :generate))($gen_fn, $args_tuple, $(GlobalRef(Gen, :EmptyChoiceMap))()))) end push!(state.stmts, :($weight += $incr)) - push!(state.stmts, :($num_nonempty_fieldname += !$qn_isempty($qn_get_choices($subtrace)) ? 1 : 0)) - push!(state.stmts, :($(node.name) = $qn_get_retval($subtrace))) - push!(state.stmts, :($total_score_fieldname += $qn_get_score($subtrace))) - push!(state.stmts, :($total_noise_fieldname += $qn_project($subtrace, $qn_empty_selection))) + push!(state.stmts, :($num_nonempty_fieldname += !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) ? 
1 : 0)) + push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) + push!(state.stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace))) + push!(state.stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()))) end function codegen_generate(gen_fn_type::Type{T}, args, @@ -51,7 +51,7 @@ function codegen_generate(gen_fn_type::Type{T}, args, # convert the constraints to a static assignment if it is not already one if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema)) - return quote $qn_generate(gen_fn, args, $(QuoteNode(StaticChoiceMap))(constraints)) end + return quote $(GlobalRef(Gen, :generate))(gen_fn, args, $(QuoteNode(StaticChoiceMap))(constraints)) end end ir = get_ir(gen_fn_type) @@ -88,7 +88,7 @@ function codegen_generate(gen_fn_type::Type{T}, args, end push!(generated_functions, quote -@generated function $(Expr(:(.), Gen, QuoteNode(:generate)))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), +@generated function $(GlobalRef(Gen, :generate))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap))) $(QuoteNode(codegen_generate))(gen_fn, args, constraints) end diff --git a/src/static_ir/project.jl b/src/static_ir/project.jl index ed14a2f29..62df493ab 100644 --- a/src/static_ir/project.jl +++ b/src/static_ir/project.jl @@ -12,10 +12,10 @@ function process!(state::StaticIRProjectState, node::GenerativeFunctionCallNode) subtrace = get_subtrace_fieldname(node) subselection = gensym("subselection") if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.stmts, :($subselection = $qn_static_getindex(selection, Val($addr)))) - push!(state.stmts, :($weight += $qn_project(trace.$subtrace, $subselection))) + push!(state.stmts, :($subselection = $(GlobalRef(Gen, :static_getindex))(selection, Val($addr)))) + push!(state.stmts, :($weight += $(GlobalRef(Gen, :project))(trace.$subtrace, $subselection))) else - push!(state.stmts, :($weight += $qn_project(trace.$subtrace, $qn_empty_selection))) + push!(state.stmts, :($weight += $(GlobalRef(Gen, :project))(trace.$subtrace, $(GlobalRef(Gen, :EmptySelection))()))) end end @@ -25,7 +25,7 @@ function codegen_project(trace_type::Type, selection_type::Type) # convert the selection to a static selection if it is not already one if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema)) - return quote $qn_project(trace, $(QuoteNode(StaticSelection))(selection)) end + return quote $(GlobalRef(Gen, :project))(trace, $(QuoteNode(StaticSelection))(selection)) end end ir = get_ir(gen_fn_type) @@ -48,11 +48,11 @@ end push!(generated_functions, quote -@generated function $(Expr(:(.), Gen, QuoteNode(:project)))(trace::T, selection::$(QuoteNode(Selection))) where {T <: $(QuoteNode(StaticIRTrace))} +@generated function $(GlobalRef(Gen, :project))(trace::T, selection::$(QuoteNode(Selection))) where {T <: $(QuoteNode(StaticIRTrace))} $(QuoteNode(codegen_project))(trace, selection) end -function $(Expr(:(.), Gen, QuoteNode(:project)))(trace::T, selection::$(QuoteNode(EmptySelection))) where {T <: $(QuoteNode(StaticIRTrace))} +function $(GlobalRef(Gen, :project))(trace::T, selection::$(QuoteNode(EmptySelection))) where {T <: $(QuoteNode(StaticIRTrace))} trace.$total_noise_fieldname end diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl index 267183ac1..669f40753 100644 --- 
a/src/static_ir/simulate.jl +++ b/src/static_ir/simulate.jl @@ -28,10 +28,10 @@ function process!(state::StaticIRSimulateState, node::GenerativeFunctionCallNode subtrace = get_subtrace_fieldname(node) incr = gensym("weight") push!(state.stmts, :($subtrace = $(QuoteNode(simulate))($gen_fn, $args_tuple))) - push!(state.stmts, :($num_nonempty_fieldname += !$qn_isempty($qn_get_choices($subtrace)) ? 1 : 0)) - push!(state.stmts, :($(node.name) = $qn_get_retval($subtrace))) - push!(state.stmts, :($total_score_fieldname += $qn_get_score($subtrace))) - push!(state.stmts, :($total_noise_fieldname += $qn_project($subtrace, $qn_empty_selection))) + push!(state.stmts, :($num_nonempty_fieldname += !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) ? 1 : 0)) + push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) + push!(state.stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace))) + push!(state.stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()))) end function codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenerativeFunction} @@ -70,7 +70,7 @@ function codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenera end push!(generated_functions, quote -@generated function $(Expr(:(.), Gen, QuoteNode(:simulate)))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::Tuple) +@generated function $(GlobalRef(Gen, :simulate))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::Tuple) $(QuoteNode(codegen_simulate))(gen_fn, args) end end) diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index f63fbcdd2..5b156d0aa 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -51,14 +51,14 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati params_grad::Dict{Symbol,Any} params::Dict{Symbol,Any} end - (gen_fn::$gen_fn_type_name)(args...) = propose(gen_fn, args)[3] - $(Expr(:(.), Gen, QuoteNode(:get_ir)))(::Type{$gen_fn_type_name}) = $(QuoteNode(ir)) - $(Expr(:(.), Gen, QuoteNode(:get_trace_type)))(::Type{$gen_fn_type_name}) = $trace_struct_name - $(Expr(:(.), Gen, QuoteNode(:has_argument_grads)))(::$gen_fn_type_name) = $(QuoteNode(has_argument_grads)) - $(Expr(:(.), Gen, QuoteNode(:accepts_output_grad)))(::$gen_fn_type_name) = $(QuoteNode(accepts_output_grad)) - $(Expr(:(.), Gen, QuoteNode(:get_gen_fn)))(trace::$trace_struct_name) = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref))) - $(Expr(:(.), Gen, QuoteNode(:get_gen_fn_type)))(::Type{$trace_struct_name}) = $gen_fn_type_name - $(Expr(:(.), Gen, QuoteNode(:get_options)))(::Type{$gen_fn_type_name}) = $(QuoteNode(options)) + (gen_fn::$gen_fn_type_name)(args...) 
= $(GlobalRef(Gen, :propose))(gen_fn, args)[3] + $(GlobalRef(Gen, :get_ir))(::Type{$gen_fn_type_name}) = $(QuoteNode(ir)) + $(GlobalRef(Gen, :get_trace_type))(::Type{$gen_fn_type_name}) = $trace_struct_name + $(GlobalRef(Gen, :has_argument_grads))(::$gen_fn_type_name) = $(QuoteNode(has_argument_grads)) + $(GlobalRef(Gen, :accepts_output_grad))(::$gen_fn_type_name) = $(QuoteNode(accepts_output_grad)) + $(GlobalRef(Gen, :get_gen_fn))(trace::$trace_struct_name) = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref))) + $(GlobalRef(Gen, :get_gen_fn_type))(::Type{$trace_struct_name}) = $gen_fn_type_name + $(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options)) end Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) end @@ -74,31 +74,6 @@ const trace = gensym("trace") const weight = gensym("weight") const subtrace = gensym("subtrace") -# quoted values and function called in generated code (since generated code is -# evaluted in the user's Main module, not Gen) -const qn_isempty = QuoteNode(isempty) -const qn_get_score = QuoteNode(get_score) -const qn_get_retval = QuoteNode(get_retval) -const qn_project = QuoteNode(project) -const qn_logpdf = QuoteNode(logpdf) -const qn_get_choices = QuoteNode(get_choices) -const qn_random = QuoteNode(random) -const qn_simulate = QuoteNode(simulate) -const qn_generate = QuoteNode(generate) -const qn_update = QuoteNode(update) -const qn_regenerate = QuoteNode(regenerate) -const qn_strip_diff = QuoteNode(strip_diff) -const qn_get_diff = QuoteNode(get_diff) -const qn_Diffed = QuoteNode(Diffed) -const qn_unknown_change = QuoteNode(UnknownChange()) -const qn_no_change = QuoteNode(NoChange()) -const qn_get_internal_node = QuoteNode(get_internal_node) -const qn_static_get_value = QuoteNode(static_get_value) -const qn_static_get_submap = QuoteNode(static_get_submap) -const qn_static_getindex = QuoteNode(static_getindex) # for getting a subselection -const qn_empty_choice_map = QuoteNode(EmptyChoiceMap()) -const qn_empty_selection = QuoteNode(EmptySelection()) - include("simulate.jl") include("generate.jl") include("project.jl") diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 56209e846..a38cfdbd7 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -104,7 +104,7 @@ end function generate_get_score(trace_struct_name::Symbol) Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_score)), :(trace::$trace_struct_name)), + Expr(:call, GlobalRef(Gen, :get_score), :(trace::$trace_struct_name)), Expr(:block, :(trace.$total_score_fieldname))) end @@ -112,13 +112,13 @@ function generate_get_args(ir::StaticIR, trace_struct_name::Symbol) args = Expr(:tuple, [:(trace.$(get_value_fieldname(node))) for node in ir.arg_nodes]...) 
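    # For illustration (hypothetical field names): with two IR arguments this builds an
    # expression like :( (trace.xs_value, trace.n_value) ), which is spliced in as the body
    # of the generated `Gen.get_args(trace::...)` method constructed just below.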
Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_args)), :(trace::$trace_struct_name)), + Expr(:call, GlobalRef(Gen, :get_args), :(trace::$trace_struct_name)), Expr(:block, args)) end function generate_get_retval(ir::StaticIR, trace_struct_name::Symbol) Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_retval)), :(trace::$trace_struct_name)), + Expr(:call, GlobalRef(Gen, :get_retval), :(trace::$trace_struct_name)), Expr(:block, :(trace.$return_value_fieldname))) end @@ -127,32 +127,32 @@ function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) for node in ir.call_nodes addr = node.addr subtrace = :(choices.trace.$(get_subtrace_fieldname(node))) - push!(elements, :(($(QuoteNode(addr)), get_choices($subtrace)))) + push!(elements, :(($(QuoteNode(addr)), $(GlobalRef(Gen, :get_choices))($subtrace)))) end - Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_submaps_shallow)), + Expr(:function, + Expr(:call, GlobalRef(Gen, :get_submaps_shallow), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name})), Expr(:block, Expr(:tuple, elements...))) end -function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) +function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) get_subtrace_exprs = Expr[] for node in ir.call_nodes push!(get_subtrace_exprs, quote - function Gen.static_get_subtrace(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) + function $(GlobalRef(Gen, :static_get_subtrace))(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) return trace.$(get_subtrace_fieldname(node)) end end ) end - + call_getindex_exprs = Expr[] for node in ir.call_nodes push!(call_getindex_exprs, quote - function Gen.static_getindex(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) - return get_retval(trace.$(get_subtrace_fieldname(node))) + function $(GlobalRef(Gen, :static_getindex))(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) + return $(GlobalRef(Gen, :get_retval))(trace.$(get_subtrace_fieldname(node))) end end ) @@ -165,11 +165,11 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) methods = Expr[] for node in ir.call_nodes push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_submap)), + Expr(:call, GlobalRef(Gen, :static_get_submap), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), :(::Val{$(QuoteNode(node.addr))})), Expr(:block, - :(get_choices(choices.trace.$(get_subtrace_fieldname(node))))))) + :($(GlobalRef(Gen, :get_choices))(choices.trace.$(get_subtrace_fieldname(node))))))) end methods @@ -178,7 +178,7 @@ end function generate_get_schema(ir::StaticIR, trace_struct_name::Symbol) addrs = [QuoteNode(node.addr) for node in ir.call_nodes] Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_schema)), :(::Type{$trace_struct_name})), + Expr(:call, GlobalRef(Gen, :get_schema), :(::Type{$trace_struct_name})), Expr(:block, :($(QuoteNode(StaticAddressSchema))( Set{Symbol}([$(addrs...)]))))) diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index ea4b17f93..b927072ea 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -99,7 +99,7 @@ end # this pass is used to determine which JuliaNodes need to be re-run (their # return value is not currently cached in the trace) -struct BackwardPassState +struct BackwardPassState marked::Set{StaticIRNode} end @@ -136,8 +136,8 @@ end ######################## function arg_values_and_diffs_from_tracked_diffs(input_nodes) - arg_values = map((node) -> 
Expr(:call, qn_strip_diff, node.name), input_nodes) - arg_diffs = map((node) -> Expr(:call, qn_get_diff, node.name), input_nodes) + arg_values = map((node) -> Expr(:call, (GlobalRef(Gen, :strip_diff)), node.name), input_nodes) + arg_diffs = map((node) -> Expr(:call, (GlobalRef(Gen, :get_diff)), node.name), input_nodes) (arg_values, arg_diffs) end @@ -151,7 +151,7 @@ end function process_codegen!(stmts, ::ForwardPassState, ::BackwardPassState, node::ArgumentNode, ::AbstractUpdateMode, options) if options.track_diffs - push!(stmts, :($(get_value_fieldname(node)) = $qn_strip_diff($(node.name)))) + push!(stmts, :($(get_value_fieldname(node)) = $(GlobalRef(Gen, :strip_diff))($(node.name)))) else push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) end @@ -166,13 +166,13 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, # track diffs if run_it arg_values, arg_diffs = arg_values_and_diffs_from_tracked_diffs(node.inputs) - args = map((v, d) -> Expr(:call, qn_Diffed, v, d), arg_values, arg_diffs) + args = map((v, d) -> Expr(:call, (GlobalRef(Gen, :Diffed)), v, d), arg_values, arg_diffs) push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...)))) elseif options.cache_julia_nodes - push!(stmts, :($(node.name) = $qn_Diffed(trace.$(get_value_fieldname(node)), $qn_no_change))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), $(GlobalRef(Gen, :NoChange))()))) end if options.cache_julia_nodes - push!(stmts, :($(get_value_fieldname(node)) = $qn_strip_diff($(node.name)))) + push!(stmts, :($(get_value_fieldname(node)) = $(GlobalRef(Gen, :strip_diff))($(node.name)))) end else @@ -206,30 +206,30 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, call_constraints = gensym("call_constraints") if node in fwd.constrained_or_selected_calls || node in fwd.input_changed if node in fwd.constrained_or_selected_calls - push!(stmts, :($call_constraints = $qn_static_get_submap(constraints, Val($addr)))) + push!(stmts, :($call_constraints = $(GlobalRef(Gen, :static_get_submap))(constraints, Val($addr)))) else - push!(stmts, :($call_constraints = $qn_empty_choice_map)) + push!(stmts, :($call_constraints = $(GlobalRef(Gen, :EmptyChoiceMap))())) end - push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node)), $(call_discard_var(node))) = - $qn_update($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_constraints))) + push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node)), $(call_discard_var(node))) = + $(GlobalRef(Gen, :update))($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_constraints))) push!(stmts, :($weight += $call_weight)) - push!(stmts, :($total_score_fieldname += $qn_get_score($subtrace) - $qn_get_score($prev_subtrace))) - push!(stmts, :($total_noise_fieldname += $qn_project($subtrace, $qn_empty_selection) - $qn_project($prev_subtrace, $qn_empty_selection))) - push!(stmts, :(if !$qn_isempty($qn_get_choices($subtrace)) && $qn_isempty($qn_get_choices($prev_subtrace)) + push!(stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace) - $(GlobalRef(Gen, :get_score))($prev_subtrace))) + push!(stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()) - $(GlobalRef(Gen, :project))($prev_subtrace, $(GlobalRef(Gen, :EmptySelection))()))) + push!(stmts, :(if !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) && $(GlobalRef(Gen, 
:isempty))($(GlobalRef(Gen, :get_choices))($prev_subtrace)) $num_nonempty_fieldname += 1 end)) - push!(stmts, :(if $qn_isempty($qn_get_choices($subtrace)) && !$qn_isempty($qn_get_choices($prev_subtrace)) + push!(stmts, :(if $(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) && !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($prev_subtrace)) $num_nonempty_fieldname -= 1 end)) if options.track_diffs - push!(stmts, :($(node.name) = $qn_Diffed($qn_get_retval($subtrace), $(calldiff_var(node))))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :get_retval))($subtrace), $(calldiff_var(node))))) else - push!(stmts, :($(node.name) = $qn_get_retval($subtrace))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) end else push!(stmts, :($subtrace = $prev_subtrace)) if options.track_diffs - push!(stmts, :($(node.name) = $qn_Diffed($qn_get_retval($subtrace), $(QuoteNode(NoChange()))))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :get_retval))($subtrace), $(QuoteNode(NoChange()))))) else - push!(stmts, :($(node.name) = $qn_get_retval($subtrace))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) end end end @@ -251,30 +251,30 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, call_subselection = gensym("call_subselection") if node in fwd.constrained_or_selected_calls || node in fwd.input_changed if node in fwd.constrained_or_selected_calls - push!(stmts, :($call_subselection = $qn_static_getindex(selection, Val($addr)))) + push!(stmts, :($call_subselection = $(GlobalRef(Gen, :static_getindex))(selection, Val($addr)))) else - push!(stmts, :($call_subselection = $qn_empty_selection)) + push!(stmts, :($call_subselection = $(GlobalRef(Gen, :EmptySelection))())) end - push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node))) = - $qn_regenerate($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_subselection))) + push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node))) = + $(GlobalRef(Gen, :regenerate))($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_subselection))) push!(stmts, :($weight += $call_weight)) - push!(stmts, :($total_score_fieldname += $qn_get_score($subtrace) - $qn_get_score($prev_subtrace))) - push!(stmts, :($total_noise_fieldname += $qn_project($subtrace, $qn_empty_selection) - $qn_project($prev_subtrace, $qn_empty_selection))) - push!(stmts, :(if !$qn_isempty($qn_get_choices($subtrace)) && !$qn_isempty($qn_get_choices($prev_subtrace)) + push!(stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace) - $(GlobalRef(Gen, :get_score))($prev_subtrace))) + push!(stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()) - $(GlobalRef(Gen, :project))($prev_subtrace, $(GlobalRef(Gen, :EmptySelection))()))) + push!(stmts, :(if !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) && !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($prev_subtrace)) $num_nonempty_fieldname += 1 end)) - push!(stmts, :(if $qn_isempty($qn_get_choices($subtrace)) && !$qn_isempty($qn_get_choices($prev_subtrace)) + push!(stmts, :(if $(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) && !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($prev_subtrace)) $num_nonempty_fieldname -= 1 end)) if options.track_diffs - push!(stmts, :($(node.name) = 
$qn_Diffed($qn_get_retval($subtrace), $(calldiff_var(node))))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :get_retval))($subtrace), $(calldiff_var(node))))) else - push!(stmts, :($(node.name) = $qn_get_retval($subtrace))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) end else push!(stmts, :($subtrace = $prev_subtrace)) if options.track_diffs - push!(stmts, :($(node.name) = $qn_Diffed($qn_get_retval($subtrace), $qn_no_change))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :get_retval))($subtrace), $(GlobalRef(Gen, :NoChange))()))) else - push!(stmts, :($(node.name) = $qn_get_retval($subtrace))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) end end end @@ -289,7 +289,7 @@ end function unpack_arguments!(stmts::Vector{Expr}, arg_nodes::Vector{ArgumentNode}, options) if options.track_diffs arg_names = Symbol[arg_node.name for arg_node in arg_nodes] - push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(map))($qn_Diffed, args, argdiffs))) + push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(map))($(GlobalRef(Gen, :Diffed)), args, argdiffs))) else arg_names = Symbol[arg_node.name for arg_node in arg_nodes] push!(stmts, :($(Expr(:tuple, arg_names...)) = args)) @@ -298,8 +298,8 @@ end function generate_return_value!(stmts::Vector{Expr}, fwd::ForwardPassState, return_node::StaticIRNode, options) if options.track_diffs - push!(stmts, :($return_value_fieldname = $qn_strip_diff($(return_node.name)))) - push!(stmts, :($retdiff = $qn_get_diff($(return_node.name)))) + push!(stmts, :($return_value_fieldname = $(GlobalRef(Gen, :strip_diff))($(return_node.name)))) + push!(stmts, :($retdiff = $(GlobalRef(Gen, :get_diff))($(return_node.name)))) else push!(stmts, :($return_value_fieldname = $(return_node.name))) push!(stmts, :($retdiff = $(QuoteNode(return_node in fwd.value_changed ? 
UnknownChange() : NoChange())))) @@ -309,7 +309,7 @@ end function generate_new_trace!(stmts::Vector{Expr}, trace_type::Type, options) if options.track_diffs # note that the generative function is the last field - constructor_args = map((name) -> Expr(:call, QuoteNode(strip_diff), name), + constructor_args = map((name) -> Expr(:call, QuoteNode(strip_diff), name), fieldnames(trace_type)[1:end-1]) push!(stmts, :($trace = $(QuoteNode(trace_type))($(constructor_args...), $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref)))))) @@ -347,7 +347,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ # convert the constraints to a static assignment if it is not already one if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema)) - return quote $qn_update(trace, args, argdiffs, $(QuoteNode(StaticChoiceMap))(constraints)) end + return quote $(GlobalRef(Gen, :update))(trace, args, argdiffs, $(QuoteNode(StaticChoiceMap))(constraints)) end end ir = get_ir(gen_fn_type) @@ -395,7 +395,7 @@ function codegen_regenerate(trace_type::Type{T}, args_type::Type, argdiffs_type: # convert a hierarchical selection to a static selection if it is not alreay one if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema)) - return quote $qn_regenerate(trace, args, argdiffs, $(QuoteNode(StaticSelection))(selection)) end + return quote $(GlobalRef(Gen, :regenerate))(trace, args, argdiffs, $(QuoteNode(StaticSelection))(selection)) end end ir = get_ir(gen_fn_type) @@ -433,14 +433,14 @@ end let T = gensym() push!(generated_functions, quote - @generated function $(Expr(:(.), Gen, QuoteNode(:update)))(trace::$T, args::Tuple, argdiffs::Tuple, + @generated function $(GlobalRef(Gen, :update))(trace::$T, args::Tuple, argdiffs::Tuple, constraints::$(QuoteNode(ChoiceMap))) where {$T<:$(QuoteNode(StaticIRTrace))} $(QuoteNode(codegen_update))(trace, args, argdiffs, constraints) end end) push!(generated_functions, quote - @generated function $(Expr(:(.), Gen, QuoteNode(:regenerate)))(trace::$T, args::Tuple, argdiffs::Tuple, + @generated function $(GlobalRef(Gen, :regenerate))(trace::$T, args::Tuple, argdiffs::Tuple, selection::$(QuoteNode(Selection))) where {$T<:$(QuoteNode(StaticIRTrace))} $(QuoteNode(codegen_regenerate))(trace, args, argdiffs, selection) end diff --git a/test/benchmarks/dynamic_mh.jl b/test/benchmarks/dynamic_mh.jl index 88392cc06..a31223f94 100644 --- a/test/benchmarks/dynamic_mh.jl +++ b/test/benchmarks/dynamic_mh.jl @@ -74,4 +74,4 @@ do_inference(xs, ys, 10) @time do_inference(xs, ys, 20) println() -end \ No newline at end of file +end diff --git a/test/benchmarks/static_mh.jl b/test/benchmarks/static_mh.jl index 0e801631c..8c4c5acad 100644 --- a/test/benchmarks/static_mh.jl +++ b/test/benchmarks/static_mh.jl @@ -84,4 +84,4 @@ println("Simple static DSL (including CallAt nodes) MH on regression model:") @time do_inference(xs, ys, 50) @time do_inference(xs, ys, 50) println() -end \ No newline at end of file +end diff --git a/test/static_dsl.jl b/test/static_dsl.jl index 5d311df6f..24a8072e1 100644 --- a/test/static_dsl.jl +++ b/test/static_dsl.jl @@ -567,4 +567,25 @@ tr, w = generate(MyModuleB.foo, (0,), choicemap(:y => 1)) end +@testset "static gen function choicemaps" begin +@gen (static) function bar2() + b ~ normal(0, 1) + return b +end +@gen (static) function bar1() + a ~ bar2() + x ~ normal(0, 1) + return x +end +Gen.load_generated_functions() +tr = simulate(bar1, ()) +ch = get_choices(tr) +@test has_value(ch, :x) 
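# Expected structure of this trace's choices (from bar1/bar2 above): :x is a top-level
# value, bar2's choice sits nested at :a => :b, and :y is absent, so it resolves to an
# EmptyChoiceMap in the assertions that follow.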
+@test !has_value(ch, :y) +@test has_value(get_submap(ch, :a), :b) +@test get_submap(ch, :y) == EmptyChoiceMap() +@test length(collect(get_values_shallow(ch))) == 1 +@test length(collect(get_submaps_shallow(ch))) == 2 +end + end # @testset "static DSL" diff --git a/test/tilde_sugar.jl b/test/tilde_sugar.jl index fbd528b76..3f3501961 100644 --- a/test/tilde_sugar.jl +++ b/test/tilde_sugar.jl @@ -1,9 +1,10 @@ using Gen import MacroTools -normalize(ex) = MacroTools.prewalk(MacroTools.rmlines, ex) - +@testset "tilde syntax" begin +normalize(ex) = + MacroTools.prewalk(MacroTools.rmlines, Gen.resolve_gen_macros(ex, Main)) # dynamic @testset "tilde syntax smoke test (dynamic)" begin @@ -73,17 +74,32 @@ end @testset "tilde syntax desugars as expected (static)" begin -expected = normalize(:( -@gen (static) function foo() - x = @trace(normal(0, 1), :x) - y = @trace(normal(0, 1), :y) -end)) - -actual = normalize(Gen.desugar_tildes(:( -@gen (static) function foo() - x ~ normal(0, 1) - y = ({:y} ~ normal(0, 1)) -end))) - -@test actual == expected + expected = normalize(:( + @gen (static) function foo() + x = @trace(normal(0, 1), :x) + y = @trace(normal(0, 1), :y) + end)) + + actual = normalize(Gen.desugar_tildes(:( + @gen (static) function foo() + x ~ normal(0, 1) + y = ({:y} ~ normal(0, 1)) + end))) + + @test actual == expected +end + +@testset "tilde syntax preserved in quoted expressions" begin + @gen function tilde_expr() + return :(x ~ normal(0, 1)) + end + @test tilde_expr() == :(x ~ normal(0, 1)) + + @gen (static) function tilde_expr() + return :(x ~ normal(0, 1)) + end + Gen.load_generated_functions() + @test tilde_expr() == :(x ~ normal(0, 1)) +end + end From 4fca5690642e3ccc72969d3ccc0d7d28c2013644 Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Fri, 19 Aug 2022 13:42:57 -0400 Subject: [PATCH 31/45] Merge in additional changes from master --- src/choice_map/choice_map.jl | 23 ++++++++++++++++++++++- src/choice_map/dynamic_choice_map.jl | 6 +++++- src/distribution.jl | 2 +- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/choice_map/choice_map.jl b/src/choice_map/choice_map.jl index a1ca2eaef..7ce059cf8 100644 --- a/src/choice_map/choice_map.jl +++ b/src/choice_map/choice_map.jl @@ -75,6 +75,13 @@ function get_value end @inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) +""" + has_submap(choices::ChoiceMap, addr) +Return true if there is a non-empty sub-assignment at the given address. +""" +function has_submap end +@inline has_submap(choices::ChoiceMap, addr) = !isempty(get_submap(choices, addr)) + """ schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} @@ -198,6 +205,20 @@ function Base.:(==)(a::ChoiceMap, b::ChoiceMap) return true end +# This is modeled after +# https://github.com/JuliaLang/julia/blob/7bff5cdd0fab8d625e48b3a9bb4e94286f2ba18c/base/abstractdict.jl#L530-L537 +const hasha_seed = UInt === UInt64 ? 
0x6d35bb51952d5539 : 0x952d5539 +function Base.hash(a::ChoiceMap, h::UInt) + hv = hasha_seed + for (addr, value) in get_values_shallow(a) + hv = xor(hv, hash(addr, hash(value))) + end + for (addr, submap) in get_submaps_shallow(a) + hv = xor(hv, hash(addr, hash(submap))) + end + return hash(hv, h) +end + function Base.isapprox(a::ChoiceMap, b::ChoiceMap) for (addr, submap) in get_submaps_shallow(a) if !isapprox(get_submap(b, addr), submap) @@ -271,7 +292,7 @@ function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) end export ChoiceMap, ValueChoiceMap, EmptyChoiceMap -export _get_submap, get_submap, get_submaps_shallow +export _get_submap, get_submap, get_submaps_shallow, has_submap export get_value, has_value export get_values_shallow, get_nonvalue_submaps_shallow export get_address_schema, get_selected diff --git a/src/choice_map/dynamic_choice_map.jl b/src/choice_map/dynamic_choice_map.jl index 0f27c89d7..112bfc92e 100644 --- a/src/choice_map/dynamic_choice_map.jl +++ b/src/choice_map/dynamic_choice_map.jl @@ -24,7 +24,11 @@ end function DynamicChoiceMap(tuples...) choices = DynamicChoiceMap() - for (addr, value) in tuples + for tuple in tuples + if length(tuple) != 2 + error("Constructor accepts tuples of the form (address, value) only") + end + (addr, value) = tuple choices[addr] = value end choices diff --git a/src/distribution.jl b/src/distribution.jl index 6354558d7..253116b0b 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -53,7 +53,7 @@ Otherwise, this element contains the gradient with respect to the `i`th argument """ function logpdf_grad end -function is_discrete end +is_discrete(::Distribution) = false # default # NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl From dc7d9a9174c25716e41873aca5e5fe47357d5f95 Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Fri, 19 Aug 2022 14:53:56 -0400 Subject: [PATCH 32/45] Remove methods referring to choice nodes --- src/dynamic/update.jl | 5 ----- src/static_ir/print_ir.jl | 5 ----- 2 files changed, 10 deletions(-) diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index b57729c0a..913801a56 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -94,11 +94,6 @@ function update_delete_recurse(prev_trie::Trie{Any,CallRecord}, score end -function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, - visited::AllSelection) - 0. 
-end - function add_unvisited_to_discard!(discard::DynamicChoiceMap, visited::DynamicSelection, prev_choices::ChoiceMap) diff --git a/src/static_ir/print_ir.jl b/src/static_ir/print_ir.jl index 0a8c52b88..f6291a700 100644 --- a/src/static_ir/print_ir.jl +++ b/src/static_ir/print_ir.jl @@ -27,10 +27,5 @@ function print_ir(io::IO, node::GenerativeFunctionCallNode) print(io, "$(node.name) = @trace($(gen_fn_name)($inputs), :$(node.addr))") end -function print_ir(io::IO, node::RandomChoiceNode) - inputs = join((string(i.name) for i in node.inputs), ", ") - print(io, "$(node.name) = @trace($(node.dist)($inputs), :$(node.addr))") -end - ir_name(fn::GenerativeFunction) = nameof(typeof(fn)) ir_name(fn::DynamicDSLFunction) = nameof(fn.julia_function) From fc03fbdbaee634a2a07dbc02a79b68570df9c350 Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Fri, 19 Aug 2022 14:54:24 -0400 Subject: [PATCH 33/45] Use Ints in trace translator tests for discrete values --- test/inference/trace_translators.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/inference/trace_translators.jl b/test/inference/trace_translators.jl index 6008ff5cf..fc2c365de 100644 --- a/test/inference/trace_translators.jl +++ b/test/inference/trace_translators.jl @@ -97,7 +97,7 @@ end else @write(p2_trace[:z], true, :discrete) x = @read(p1_trace[:x], :continuous) - i = ceil(x * 10) + i = Int(ceil(x * 10)) @write(p2_trace[:i], i, :discrete) @write(q2_trace[:dx], x - (i-1)/10, :continuous) end @@ -135,8 +135,8 @@ end @transform f (p1_trace, q1_trace) to (p2_trace, q2_trace) begin x = @read(p1_trace[:x], :continuous) y = @read(p1_trace[:y], :continuous) - i = ceil(x * 10) - j = ceil(y * 10) + i = Int(ceil(x * 10)) + j = Int(ceil(y * 10)) @write(p2_trace[:i], i, :discrete) @write(p2_trace[:j], j, :discrete) @write(q2_trace[:dx], x - (i-1)/10, :continuous) From 33b7c107cef1264bdf2c014930d56f7c6d4988bb Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Fri, 19 Aug 2022 14:54:45 -0400 Subject: [PATCH 34/45] Fix choicemap equality --- src/choice_map/choice_map.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/choice_map/choice_map.jl b/src/choice_map/choice_map.jl index 7ce059cf8..6e0f5bd7b 100644 --- a/src/choice_map/choice_map.jl +++ b/src/choice_map/choice_map.jl @@ -136,8 +136,6 @@ struct EmptyChoiceMap <: ChoiceMap end @inline get_submaps_shallow(::EmptyChoiceMap) = () @inline get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() @inline Base.:(==)(::EmptyChoiceMap, ::EmptyChoiceMap) = true -@inline Base.:(==)(::ChoiceMap, ::EmptyChoiceMap) = false -@inline Base.:(==)(::EmptyChoiceMap, ::ChoiceMap) = false """ ValueChoiceMap @@ -153,6 +151,8 @@ end @inline get_submap(choices::ValueChoiceMap, addr) = EmptyChoiceMap() @inline get_submaps_shallow(choices::ValueChoiceMap) = () @inline Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val +@inline Base.:(==)(a::ValueChoiceMap, b::ChoiceMap) = false +@inline Base.:(==)(a::ChoiceMap, b::ValueChoiceMap) = false @inline Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) @inline get_address_schema(::Type{<:ValueChoiceMap}) = AllAddressSchema() From 7780c70e78190de941074d9b701beddf6ee9120e Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Fri, 19 Aug 2022 14:55:04 -0400 Subject: [PATCH 35/45] update has_submap tests to reflect that ValueChoiceMaps are submaps --- test/assignment.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/assignment.jl b/test/assignment.jl index 
c151c4f9e..8ec08f992 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -264,7 +264,8 @@ end choices = choicemap() choices[:x] = 1 @test has_value(choices, :x) - @test !has_submap(choices, :x) + @test has_submap(choices, :x) + @test !has_submap(choices, :z) submap = choicemap(); submap[:y] = 2 set_submap!(choices, :x, submap) @test !has_value(choices, :x) @@ -282,7 +283,8 @@ end choices = choicemap() choices[:x => :y] = 1 @test has_submap(choices, :x) - @test !has_submap(choices, :x => :y) + @test has_submap(choices, :x => :y) # valuechoicemap + @test !has_submap(choices, :x => :z) submap = choicemap(); submap[:z] = 2 set_submap!(choices, :x, submap) @test !isempty(get_submap(choices, :x)) From 5ffad02d978e2d03fa42cba589e4516f79029dea Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Fri, 19 Aug 2022 14:55:39 -0400 Subject: [PATCH 36/45] fix some iterator issues --- src/modeling_library/switch/update.jl | 4 ++-- test/gen_fn_interface.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index 02923a776..853fd3aa7 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -19,14 +19,14 @@ function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap) # Add (address, value) to new_choices from prev_choices if address does not occur in choices. for (address, value) in prev_choice_value_iterator - address in keys(choice_value_iterator) && continue + address in map(first, collect(choice_value_iterator)) && continue set_value!(new_choices, address, value) end # Add (address, submap) to new_choices from prev_choices if address does not occur in choices. # If it does, enter a recursive call to update_recurse_merge. for (address, node1) in prev_choice_submap_iterator - if address in keys(choice_submap_iterator) + if address in map(first, collect(choice_submap_iterator)) node2 = get_submap(choices, address) node = update_recurse_merge(node1, node2) set_submap!(new_choices, address, node) diff --git a/test/gen_fn_interface.jl b/test/gen_fn_interface.jl index be41ae51f..c7bb0e9a1 100644 --- a/test/gen_fn_interface.jl +++ b/test/gen_fn_interface.jl @@ -19,7 +19,7 @@ for (lang, f) in [:dynamic => f_dynamic, # sanity-check that `update` did what it's supposed to. @test get_args(trace1) == (5, 6) @test trace1[:z] == 0 - @test :z in keys(get_values_shallow(discard)) + @test :z in map(first, get_values_shallow(discard)) end @testset "regenerate(...) shorthand assuming unchanged args ($lang modeling lang)" begin From 2845e3bffa67fe994775a43873eef7f56075d000 Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Fri, 19 Aug 2022 14:55:59 -0400 Subject: [PATCH 37/45] handle complement selections in project(::DistributionTrace, ...) --- src/distribution.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/distribution.jl b/src/distribution.jl index 253116b0b..6a88a661d 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -72,7 +72,10 @@ get_return_type(::Distribution{T}) where {T} = T @inline Gen.get_score(trace::DistributionTrace) = trace.score @inline Gen.project(trace::DistributionTrace, ::EmptySelection) = 0. 
@inline Gen.project(trace::DistributionTrace, ::AllSelection) = get_score(trace) - +@inline Gen.project(trace::DistributionTrace, c::ComplementSelection) = project_complement(trace, c.complement) +@inline project_complement(trace::DistributionTrace, ::EmptySelection) = get_score(trace) +@inline project_complement(trace::DistributionTrace, ::AllSelection) = 0. +@inline project_complement(trace::DistributionTrace, c::ComplementSelection) = project_complement(trace, c.complement) @inline function Gen.simulate(dist::Distribution, args::Tuple) val = random(dist, args...) DistributionTrace(val, args, dist) From 4c93efcece109181bc025fc96943f8251c0f626e Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Fri, 19 Aug 2022 16:27:33 -0400 Subject: [PATCH 38/45] Fix project_complement --- src/distribution.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution.jl b/src/distribution.jl index 6a88a661d..2f35cf0c1 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -75,7 +75,7 @@ get_return_type(::Distribution{T}) where {T} = T @inline Gen.project(trace::DistributionTrace, c::ComplementSelection) = project_complement(trace, c.complement) @inline project_complement(trace::DistributionTrace, ::EmptySelection) = get_score(trace) @inline project_complement(trace::DistributionTrace, ::AllSelection) = 0. -@inline project_complement(trace::DistributionTrace, c::ComplementSelection) = project_complement(trace, c.complement) +@inline project_complement(trace::DistributionTrace, c::ComplementSelection) = Gen.project(trace, c.complement) @inline function Gen.simulate(dist::Distribution, args::Tuple) val = random(dist, args...) DistributionTrace(val, args, dist) From d2e80f59148d9fd81780dce4fc06404b6db57f91 Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Sat, 20 Aug 2022 14:31:48 -0400 Subject: [PATCH 39/45] Make DistributionTrace store a reference to the distribution --- src/distribution.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/distribution.jl b/src/distribution.jl index 2f35cf0c1..7742ea9aa 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -6,18 +6,19 @@ struct DistributionTrace{T, Dist} <: Trace val::T args score::Float64 + dist::Dist end -@inline dist(::DistributionTrace{T, Dist}) where {T, Dist} = Dist() +@inline dist(tr::DistributionTrace{T, Dist}) where {T, Dist} = tr.dist abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end -DistributionTrace{T, Dist}(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...)) -@inline DistributionTrace(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...)) +DistributionTrace{T, Dist}(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...), dist) +@inline DistributionTrace(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...), dist) # we need to know the specific distribution in the trace type so the compiler can specialize GFI calls fully @inline get_trace_type(::Dist) where {T, Dist <: Distribution{T}} = DistributionTrace{T, Dist} function Base.convert(::Type{<:DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} - DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, tr.score) + DistributionTrace{U, Dist}(convert(U, tr.val), 
tr.args, tr.score, tr.dist) end """ From e07701bd1bdcc936d05867481be71895d898157e Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Sat, 20 Aug 2022 14:32:19 -0400 Subject: [PATCH 40/45] Remove dependency of benchmarks on no-longer-extant examples directory --- test/benchmarks/dataset.jl | 20 ++++++++++++++++++++ test/benchmarks/dynamic_mh.jl | 32 ++++++++++++++++++++++++++++++-- test/benchmarks/static_mh.jl | 26 +++++++++++++++++++++++--- 3 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 test/benchmarks/dataset.jl diff --git a/test/benchmarks/dataset.jl b/test/benchmarks/dataset.jl new file mode 100644 index 000000000..5cd6308a7 --- /dev/null +++ b/test/benchmarks/dataset.jl @@ -0,0 +1,20 @@ + +function make_data_set(n) + Random.seed!(1) + prob_outlier = 0.5 + true_inlier_noise = 0.5 + true_outlier_noise = 5.0 + true_slope = -1 + true_intercept = 2 + xs = collect(range(-5, stop=5, length=n)) + ys = Float64[] + for (i, x) in enumerate(xs) + if rand() < prob_outlier + y = true_slope * x + true_intercept + randn() * true_inlier_noise + else + y = true_slope * x + true_intercept + randn() * true_outlier_noise + end + push!(ys, y) + end + (xs, ys) +end \ No newline at end of file diff --git a/test/benchmarks/dynamic_mh.jl b/test/benchmarks/dynamic_mh.jl index a31223f94..fdaaf54f9 100644 --- a/test/benchmarks/dynamic_mh.jl +++ b/test/benchmarks/dynamic_mh.jl @@ -2,8 +2,36 @@ module DynamicMHBenchmark using Gen import Random -include("../../examples/regression/dynamic_model.jl") -include("../../examples/regression/dataset.jl") +include("dataset.jl") + +######### +# model # +######### + +# TODO put this into FunctionalCollections: +import FunctionalCollections +Base.IndexStyle(::Type{<:FunctionalCollections.PersistentVector}) = IndexLinear() + +@gen function datum(x::Float64, (grad)(inlier_std::Float64), (grad)(outlier_std), (grad)(slope), (grad)(intercept))::Float64 + is_outlier = @trace(bernoulli(0.5), :z) + std = is_outlier ? inlier_std : outlier_std + y = @trace(normal(x * slope + intercept, std), :y) + return y +end + +data = Map(datum) + +@gen function model(xs::Vector{Float64}) + n = length(xs) + inlier_std = exp(@trace(normal(0, 2), :log_inlier_std)) + outlier_std = exp(@trace(normal(0, 2), :log_outlier_std)) + slope = @trace(normal(0, 2), :slope) + intercept = @trace(normal(0, 2), :intercept) + ys = @trace(data(xs, fill(inlier_std, n), fill(outlier_std, n), + fill(slope, n), fill(intercept, n)), :data) + return ys +end + @gen function slope_proposal(trace) slope = trace[:slope] diff --git a/test/benchmarks/static_mh.jl b/test/benchmarks/static_mh.jl index 8c4c5acad..6ae1545d2 100644 --- a/test/benchmarks/static_mh.jl +++ b/test/benchmarks/static_mh.jl @@ -1,9 +1,9 @@ module StaticMHBenchmark using Gen import Random +using Gen: ifelse -include("../../examples/regression/static_model.jl") -include("../../examples/regression/dataset.jl") +include("dataset.jl") @gen (static) function slope_proposal(trace) slope = trace[:slope] @@ -39,7 +39,27 @@ end @trace(bernoulli(prev_z ? 
0.0 : 1.0), :data => i => :z) end -Gen.load_generated_functions() +@gen (static) function datum(x::Float64, (grad)(inlier_std::Float64), (grad)(outlier_std::Float64), + (grad)(slope::Float64), (grad)(intercept::Float64)) + is_outlier = @trace(bernoulli(0.5), :z) + std = ifelse(is_outlier, inlier_std, outlier_std) + y = @trace(normal(x * slope + intercept, std), :y) + return y +end + +data = Map(datum) + +@gen (static) function model(xs::Vector{Float64}) + n = length(xs) + inlier_log_std = @trace(normal(0, 2), :log_inlier_std) + outlier_log_std = @trace(normal(0, 2), :log_outlier_std) + inlier_std = exp(inlier_log_std) + outlier_std = exp(outlier_log_std) + slope = @trace(normal(0, 2), :slope) + intercept = @trace(normal(0, 2), :intercept) + @trace(data(xs, fill(inlier_std, n), fill(outlier_std, n), + fill(slope, n), fill(intercept, n)), :data) +end function do_inference(xs, ys, num_iters) observations = choicemap() From 577b289e3627f86ea870094c8bc065855498780a Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Sat, 20 Aug 2022 14:34:22 -0400 Subject: [PATCH 41/45] Implement `choice_gradients` and `accumulate_param_gradients!` for distributions --- src/distribution.jl | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/distribution.jl b/src/distribution.jl index 7742ea9aa..2f73260c6 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -38,7 +38,7 @@ function logpdf end """ has::Bool = has_output_grad(dist::Distribution) -Return true of the gradient if the distribution computes the gradient of the logpdf with respect to the value of the random choice. +Return true if the distribution computes the gradient of the logpdf with respect to the value of the random choice. """ function has_output_grad end @@ -120,6 +120,45 @@ end (weight, choices.val) end + +# Gradient-based methods +@inline Gen.accepts_output_grad(dist::Distribution) = has_output_grad(dist) + +function Gen.choice_gradients(tr::DistributionTrace, ::AllSelection, retgrad) + if !has_output_grad(dist(tr)) + error("Distribution $(dist(tr)) does not compute gradient of logpdf with respect to value") + end + output_grad, arg_grads... = logpdf_grad(dist(tr), tr.val, tr.args...) + choice_values = ValueChoiceMap(tr.val) + choice_grads = ValueChoiceMap(output_grad + retgrad) + return arg_grads, choice_values, choice_grads +end + +@inline function Gen.choice_gradients(tr::DistributionTrace, ::EmptySelection, retgrad) + _, arg_grads... = logpdf_grad(dist(tr), tr.val, tr.args...) + choice_values = EmptyChoiceMap() + choice_grads = EmptyChoiceMap() + return arg_grads, choice_values, choice_grads +end + +function Gen.choice_gradients(tr::DistributionTrace, c::ComplementSelection, retgrad) + if c.complement isa EmptySelection + return choice_gradients(tr, AllSelection(), retgrad) + elseif c.complement isa AllSelection + return choice_gradients(tr, EmptySelection(), retgrad) + elseif c.complement isa ComplementSelection + return choice_gradients(tr, c.complement.complement, retgrad) + else + error("Choice gradients not implemented for generic complement selection") + end +end + +@inline function Gen.accumulate_param_gradients!(tr::DistributionTrace, retgrad, scale_factor) + _, arg_grads... = logpdf_grad(dist(tr), tr.val, tr.args...) 
+ return arg_grads +end + + ########### # Exports # ########### From f7044a214181588b5502ac08857013ecbdf47f39 Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Sat, 20 Aug 2022 14:34:42 -0400 Subject: [PATCH 42/45] Delete `load_generated_functions` calls from benchmarks --- test/static_choicemap_benchmark.jl | 2 -- test/static_inference_benchmark.jl | 2 -- 2 files changed, 4 deletions(-) diff --git a/test/static_choicemap_benchmark.jl b/test/static_choicemap_benchmark.jl index 1e62b9a8e..7b62fe8cf 100644 --- a/test/static_choicemap_benchmark.jl +++ b/test/static_choicemap_benchmark.jl @@ -34,8 +34,6 @@ end b ~ inner() end -load_generated_functions() - tr, _ = generate(outer, ()) choices = get_choices(tr) diff --git a/test/static_inference_benchmark.jl b/test/static_inference_benchmark.jl index b70d08be2..953af236b 100644 --- a/test/static_inference_benchmark.jl +++ b/test/static_inference_benchmark.jl @@ -6,8 +6,6 @@ using Gen c ~ normal(b, 1) end -@load_generated_functions - observations = StaticChoiceMap(choicemap((:b,2), (:c,1.5))) tr, _ = generate(foo, (), observations) From 09dc792fa6fa24c6ebec3341d2ff213690ef554b Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Sat, 20 Aug 2022 14:35:42 -0400 Subject: [PATCH 43/45] Remove special handling of `Distribution` from dynamic DSL backprop code --- src/dynamic/backprop.jl | 63 ----------------------------------------- 1 file changed, 63 deletions(-) diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index e57d19a32..238c173d7 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -2,34 +2,6 @@ function maybe_track(arg, has_argument_grad::Bool, tape) has_argument_grad ? track(arg, tape) : arg end -@noinline function ReverseDiff.special_reverse_exec!( - instruction::ReverseDiff.SpecialInstruction{D}) where {D <: Distribution} - dist::D = instruction.func - args_maybe_tracked = instruction.input - score_tracked = instruction.output - arg_grads = logpdf_grad(dist, map(value, args_maybe_tracked)...) - value_tracked = args_maybe_tracked[1] - value_grad = arg_grads[1] - if istracked(value_tracked) - if has_output_grad(dist) - increment_deriv!(value_tracked, value_grad * deriv(score_tracked)) - else - error("Gradient required but not available for return value of distribution $dist") - end - end - for (i, (arg_maybe_tracked, grad, has_grad)) in enumerate( - zip(args_maybe_tracked[2:end], arg_grads[2:end], has_argument_grads(dist))) - if istracked(arg_maybe_tracked) - if has_grad - increment_deriv!(arg_maybe_tracked, grad * deriv(score_tracked)) - else - error("Gradient required but not available for argument $i of $dist") - end - end - end - nothing -end - ################### # accumulate_param_gradients! 
# @@ -70,19 +42,6 @@ function read_param(state::GFBackpropParamsState, name::Symbol) value end -function traceat(state::GFBackpropParamsState, dist::Distribution{T}, - args_maybe_tracked, key) where {T} - local retval::T - visit!(state.visitor, key) - retval = get_retval(get_call(state.trace, key).subtrace) - args = map(value, args_maybe_tracked) - score_tracked = track(logpdf(dist, retval, args...), state.tape) - record!(state.tape, ReverseDiff.SpecialInstruction, dist, - (retval, args_maybe_tracked...,), score_tracked) - state.score += score_tracked - retval -end - struct BackpropParamsRecord gen_fn::GenerativeFunction subtrace::Any @@ -271,28 +230,6 @@ function fill_map!( fill_submaps!(map, tracked_trie, mode) end -function traceat(state::GFBackpropTraceState, dist::Distribution{T}, - args_maybe_tracked, key) where {T} - local retval::T - visit!(state.visitor, key) - retval = get_retval(get_call(state.trace, key).subtrace) - args = map(value, args_maybe_tracked) - score_tracked = track(logpdf(dist, retval, args...), state.tape) - if key in state.selection - tracked_retval = track(retval, state.tape) - set_leaf_node!(state.tracked_choices, key, tracked_retval) - record!(state.tape, ReverseDiff.SpecialInstruction, dist, - (tracked_retval, args_maybe_tracked...,), score_tracked) - state.score += score_tracked - return tracked_retval - else - record!(state.tape, ReverseDiff.SpecialInstruction, dist, - (retval, args_maybe_tracked...,), score_tracked) - state.score += score_tracked - return retval - end -end - struct BackpropTraceRecord gen_fn::GenerativeFunction subtrace::Any From 370b8227ceed2b3803e421e711d3c01a49cae7e5 Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Sat, 20 Aug 2022 15:00:57 -0400 Subject: [PATCH 44/45] Remove special case handling of distributions from static DSL backprop methods --- src/distribution.jl | 2 +- src/modeling_library/call_at/call_at.jl | 50 +----- src/static_ir/backprop.jl | 210 ++++++++---------------- 3 files changed, 76 insertions(+), 186 deletions(-) diff --git a/src/distribution.jl b/src/distribution.jl index 2f73260c6..e81d7f75c 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -130,7 +130,7 @@ function Gen.choice_gradients(tr::DistributionTrace, ::AllSelection, retgrad) end output_grad, arg_grads... = logpdf_grad(dist(tr), tr.val, tr.args...) choice_values = ValueChoiceMap(tr.val) - choice_grads = ValueChoiceMap(output_grad + retgrad) + choice_grads = ValueChoiceMap(isnothing(retgrad) ? output_grad : output_grad .+ retgrad) return arg_grads, choice_values, choice_grads end diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index eb690f8b8..cbd9ce400 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -140,50 +140,18 @@ function regenerate(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, end function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad) - if trace.subtrace isa DistributionTrace - if retval_grad !== nothing && !has_output_grad(get_gen_fn(trace.subtrace)) - error("return value gradient not accepted but one was provided") - end - kernel_arg_grads = logpdf_grad(get_gen_fn(trace.subtrace), get_retval(trace.subtrace), get_args(trace.subtrace)...) 
- if trace.key in selection - value_choices = CallAtChoiceMap(trace.key, get_choices(trace.subtrace)) - choice_grad = kernel_arg_grads[1] - if choice_grad === nothing - error("gradient not available for selected choice") - end - if retval_grad !== nothing - choice_grad += retval_grad - end - gradient_choices = CallAtChoiceMap(trace.key, ValueChoiceMap(choice_grad)) - else - value_choices = EmptyChoiceMap() - gradient_choices = EmptyChoiceMap() - end - input_grads = (kernel_arg_grads[2:end]..., nothing) - return (input_grads, value_choices, gradient_choices) - else - subselection = selection[trace.key] - (kernel_input_grads, value_submap, gradient_submap) = choice_gradients( - trace.subtrace, subselection, retval_grad) - input_grads = (kernel_input_grads..., nothing) - value_choices = CallAtChoiceMap(trace.key, value_submap) - gradient_choices = CallAtChoiceMap(trace.key, gradient_submap) - return (input_grads, value_choices, gradient_choices) - end + subselection = selection[trace.key] + (kernel_input_grads, value_submap, gradient_submap) = choice_gradients( + trace.subtrace, subselection, retval_grad) + input_grads = (kernel_input_grads..., nothing) + value_choices = CallAtChoiceMap(trace.key, value_submap) + gradient_choices = CallAtChoiceMap(trace.key, gradient_submap) + return (input_grads, value_choices, gradient_choices) end function accumulate_param_gradients!(trace::CallAtTrace, retval_grad) - if trace.subtrace isa DistributionTrace - if retval_grad !== nothing && !has_output_grad(trace.gen_fn.dist) - error("return value gradient not accepted but one was provided") - end - kernel_arg_grads = logpdf_grad(get_gen_fn(trace.subtrace), get_retval(trace.subtrace), get_args(trace.subtrace)...) - return (kernel_arg_grads[2:end]..., nothing) - else - kernel_input_grads = accumulate_param_gradients!(trace.subtrace, retval_grad) - return (kernel_input_grads..., nothing) - end - + kernel_input_grads = accumulate_param_gradients!(trace.subtrace, retval_grad) + return (kernel_input_grads..., nothing) end export call_at diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index fec07a7ec..1107380b9 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -19,32 +19,26 @@ maybe_tracked_value_var(node::JuliaNode) = Symbol("$(maybe_tracked_value_prefix) const maybe_tracked_arg_prefix = gensym("maybe_tracked_arg") maybe_tracked_arg_var(node::JuliaNode, i::Int) = Symbol("$(maybe_tracked_arg_prefix)_$(node.name)_$i") -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode) +function fwd_pass!(selected_calls, fwd_marked, node::TrainableParameterNode) # TODO: only need to mark it if we are doing backprop params push!(fwd_marked, node) end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode) +function fwd_pass!(selected_calls, fwd_marked, node::ArgumentNode) if node.compute_grad push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode) +function fwd_pass!(selected_calls, fwd_marked, node::JuliaNode) if any(input_node in fwd_marked for input_node in node.inputs) push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode) - if node.generative_function isa Distribution - if node in selected_choices - push!(fwd_marked, node) - end - else - if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) - push!(fwd_marked, node) - end +function 
fwd_pass!(selected_calls, fwd_marked, node::GenerativeFunctionCallNode) + if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) + push!(fwd_marked, node) end end @@ -65,10 +59,6 @@ function back_pass!(back_marked, node::GenerativeFunctionCallNode) for input_node in node.inputs push!(back_marked, input_node) end - if node.generative_function isa Distribution - # the value of every random choice is in back_marked, since it affects its logpdf - push!(back_marked, node) - end end function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode) @@ -130,34 +120,16 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) end function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) - if node.generative_function isa Distribution - # for reference by other nodes during back_codegen! - # could performance optimize this away - push!(stmts, :($(node.name) = get_retval(trace.$(get_subtrace_fieldname(node))))) - - # every random choice is in back_marked, since it affects it logpdf, but - # also possibly due to other downstream usage of the value - @assert node in back_marked + # for reference by other nodes during back_codegen! + # could performance optimize this away + subtrace_fieldname = get_subtrace_fieldname(node) + push!(stmts, :($(node.name) = get_retval(trace.$subtrace_fieldname))) - if node in fwd_marked - # the only way we are fwd_marked is if this choice was selected - - # initialize gradient with respect to the value of the random choice to zero - # it will be a runtime error, thrown here, if there is no zero() method - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) - end - else - # for reference by other nodes during back_codegen! - # could performance optimize this away - subtrace_fieldname = get_subtrace_fieldname(node) - push!(stmts, :($(node.name) = get_retval(trace.$subtrace_fieldname))) - - # NOTE: we will still potentially run choice_gradients recursively on the generative function, - # we just might not use its return value gradient. - if node in fwd_marked && node in back_marked - # we are fwd_marked if an input was fwd_marked, or if we were selected internally - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) - end + # NOTE: we will still potentially run choice_gradients recursively on the generative function, + # we just might not use its return value gradient. 
+ if node in fwd_marked && node in back_marked + # we are fwd_marked if an input was fwd_marked, or if we were selected internally + push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) end end @@ -239,125 +211,76 @@ end function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, mode::BackpropTraceMode) - if node.generative_function isa Distribution - logpdf_grad = gensym("logpdf_grad") - - # backpropagate to the inputs - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) - - # backpropagate to the value (if it was selected) - if node in fwd_marked - if !has_output_grad(node.generative_function) - error("Distribution $(node.generative_function) does not logpdf gradient for its output value") - end - push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) - end - else - # handle case when it is the return node - if node === ir.return_node && node in fwd_marked - @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) - end + # handle case when it is the return node + if node === ir.return_node && node in fwd_marked + @assert node in back_marked + push!(stmts, :($(gradient_var(node)) += retval_grad)) + end - if node in fwd_marked - input_grads = gensym("call_input_grads") - value_trie = value_trie_var(node) - gradient_trie = gradient_trie_var(node) - subtrace_fieldname = get_subtrace_fieldname(node) - call_selection = gensym("call_selection") - if node in selected_calls - push!(stmts, :($call_selection = $(GlobalRef(Gen, :static_getindex))(selection, $(QuoteNode(Val(node.addr)))))) - else - push!(stmts, :($call_selection = EmptySelection())) - end - retval_grad = node in back_marked ? gradient_var(node) : :(nothing) - push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( - trace.$subtrace_fieldname, $call_selection, $retval_grad))) + if node in fwd_marked + input_grads = gensym("call_input_grads") + value_trie = value_trie_var(node) + gradient_trie = gradient_trie_var(node) + subtrace_fieldname = get_subtrace_fieldname(node) + call_selection = gensym("call_selection") + if node in selected_calls + push!(stmts, :($call_selection = $(GlobalRef(Gen, :static_getindex))(selection, $(QuoteNode(Val(node.addr)))))) + else + push!(stmts, :($call_selection = EmptySelection())) end + retval_grad = node in back_marked ? 
gradient_var(node) : :(nothing) + push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( + trace.$subtrace_fieldname, $call_selection, $retval_grad))) + end - # increment gradients of input nodes that are in fwd_marked - for (i, input_node) in enumerate(node.inputs) - if input_node in fwd_marked - @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) - end + # increment gradients of input nodes that are in fwd_marked + for (i, input_node) in enumerate(node.inputs) + if input_node in fwd_marked + @assert input_node in back_marked # this ensured its gradient will have been initialized + push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) end - - # NOTE: the value_trie and gradient_trie are dealt with later end + + # NOTE: the value_trie and gradient_trie are dealt with later end function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, mode::BackpropParamsMode) - if node.generative_function isa Distribution - logpdf_grad = gensym("logpdf_grad") - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) - else - # handle case when it is the return node - if node === ir.return_node && node in fwd_marked - @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) - end + # handle case when it is the return node + if node === ir.return_node && node in fwd_marked + @assert node in back_marked + push!(stmts, :($(gradient_var(node)) += retval_grad)) + end - if node in fwd_marked - input_grads = gensym("call_input_grads") - subtrace_fieldname = get_subtrace_fieldname(node) - retval_grad = node in back_marked ? gradient_var(node) : :(nothing) - push!(stmts, :($input_grads = accumulate_param_gradients!(trace.$subtrace_fieldname, $retval_grad))) - end + if node in fwd_marked + input_grads = gensym("call_input_grads") + subtrace_fieldname = get_subtrace_fieldname(node) + retval_grad = node in back_marked ? 
gradient_var(node) : :(nothing) + push!(stmts, :($input_grads = accumulate_param_gradients!(trace.$subtrace_fieldname, $retval_grad))) + end - # increment gradients of input nodes that are in fwd_marked - for (i, input_node) in enumerate(node.inputs) - if input_node in fwd_marked - @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) - end + # increment gradients of input nodes that are in fwd_marked + for (i, input_node) in enumerate(node.inputs) + if input_node in fwd_marked + @assert input_node in back_marked # this ensured its gradient will have been initialized + push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) end end end -function generate_value_gradient_trie(selected_choices::Set{GenerativeFunctionCallNode}, - selected_calls::Set{GenerativeFunctionCallNode}, +function generate_value_gradient_trie(selected_calls::Set{GenerativeFunctionCallNode}, value_trie::Symbol, gradient_trie::Symbol) - selected_choices_vec = collect(selected_choices) - quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) - leaf_value_choicemaps = map((node) -> :(ValueChoiceMap(get_retval(trace.$(get_subtrace_fieldname(node))))), selected_choices_vec) - leaf_gradient_choicemaps = map((node) -> :(ValueChoiceMap($(gradient_var(node)))), selected_choices_vec) - selected_calls_vec = collect(selected_calls) quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec) internal_value_choicemaps = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), selected_calls_vec) internal_gradient_choicemaps = map((node) -> gradient_trie_var(node), selected_calls_vec) - quoted_all_keys = Iterators.flatten((quoted_leaf_keys, quoted_internal_keys)) - all_value_choicemaps = Iterators.flatten((leaf_value_choicemaps, internal_value_choicemaps)) - all_gradient_choicemaps = Iterators.flatten((leaf_gradient_choicemaps, internal_gradient_choicemaps)) - quote - $value_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_value_choicemaps...),))) - $gradient_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_gradient_choicemaps...),))) - end -end - -function get_selected_choices(::EmptyAddressSchema, ::StaticIR) - Set{GenerativeFunctionCallNode}() -end - -function get_selected_choices(::AllAddressSchema, ir::StaticIR) - Set{GenerativeFunctionCallNode}([node for node in ir.call_nodes if node.generative_function isa Distribution]...) -end - -function get_selected_choices(schema::StaticAddressSchema, ir::StaticIR) - selected_choice_addrs = Set(keys(schema)) - selected_choices = Set{GenerativeFunctionCallNode}() - for node in ir.call_nodes - if node.generative_function isa Distribution && node.addr in selected_choice_addrs - push!(selected_choices, node) - end + $value_trie = StaticChoiceMap(NamedTuple{($(quoted_internal_keys...),)}(($(internal_value_choicemaps...),))) + $gradient_trie = StaticChoiceMap(NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradient_choicemaps...),))) end - selected_choices end function get_selected_calls(::EmptyAddressSchema, ::StaticIR) @@ -365,14 +288,14 @@ function get_selected_calls(::EmptyAddressSchema, ::StaticIR) end function get_selected_calls(::AllAddressSchema, ir::StaticIR) - Set{GenerativeFunctionCallNode}([node for node in ir.call_nodes if !(node.generative_function isa Distribution)]...) + Set{GenerativeFunctionCallNode}(ir.call_nodes...) 
end function get_selected_calls(schema::StaticAddressSchema, ir::StaticIR) selected_call_addrs = Set(keys(schema)) selected_calls = Set{GenerativeFunctionCallNode}() for node in ir.call_nodes - if !(node.generative_function isa Distribution) && node.addr in selected_call_addrs + if node.addr in selected_call_addrs push!(selected_calls, node) end end @@ -390,13 +313,12 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, end ir = get_ir(gen_fn_type) - selected_choices = get_selected_choices(schema, ir) selected_calls = get_selected_calls(schema, ir) # forward marking pass fwd_marked = Set{StaticIRNode}() for node in ir.nodes - fwd_pass!(selected_choices, selected_calls, fwd_marked, node) + fwd_pass!(selected_calls, fwd_marked, node) end # backward marking pass @@ -425,7 +347,7 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # assemble value_trie and gradient_trie value_trie = gensym("value_trie") gradient_trie = gensym("gradient_trie") - push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls, + push!(stmts, generate_value_gradient_trie(selected_calls, value_trie, gradient_trie)) # gradients with respect to inputs @@ -444,7 +366,7 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, ir = get_ir(gen_fn_type) # unlike choice_gradients we don't take gradients w.r.t. the value of random choices - selected_choices = Set{GenerativeFunctionCallNode}() + # selected_choices = Set{GenerativeFunctionCallNode}() # we need to guarantee that we visit every generative function call, # because we need to backpropagate to its trainable parameters @@ -454,7 +376,7 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, # forward marking pass fwd_marked = Set{StaticIRNode}() for node in ir.nodes - fwd_pass!(selected_choices, selected_calls, fwd_marked, node) + fwd_pass!(selected_calls, fwd_marked, node) end # backward marking pass From a09d1cc4d1887438ae8c284a7000ae35a5079127 Mon Sep 17 00:00:00 2001 From: Alex Lew Date: Sat, 20 Aug 2022 15:02:23 -0400 Subject: [PATCH 45/45] Avoid "slurping" destructuring assignment, which is unavailable in older Julias --- src/distribution.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/distribution.jl b/src/distribution.jl index e81d7f75c..fdb505512 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -128,14 +128,16 @@ function Gen.choice_gradients(tr::DistributionTrace, ::AllSelection, retgrad) if !has_output_grad(dist(tr)) error("Distribution $(dist(tr)) does not compute gradient of logpdf with respect to value") end - output_grad, arg_grads... = logpdf_grad(dist(tr), tr.val, tr.args...) + grads = logpdf_grad(dist(tr), tr.val, tr.args...) + output_grad = grads[1] + arg_grads = grads[2:end] choice_values = ValueChoiceMap(tr.val) choice_grads = ValueChoiceMap(isnothing(retgrad) ? output_grad : output_grad .+ retgrad) return arg_grads, choice_values, choice_grads end @inline function Gen.choice_gradients(tr::DistributionTrace, ::EmptySelection, retgrad) - _, arg_grads... = logpdf_grad(dist(tr), tr.val, tr.args...) + arg_grads = logpdf_grad(dist(tr), tr.val, tr.args...)[2:end] choice_values = EmptyChoiceMap() choice_grads = EmptyChoiceMap() return arg_grads, choice_values, choice_grads @@ -154,7 +156,7 @@ function Gen.choice_gradients(tr::DistributionTrace, c::ComplementSelection, ret end @inline function Gen.accumulate_param_gradients!(tr::DistributionTrace, retgrad, scale_factor) - _, arg_grads... 
= logpdf_grad(dist(tr), tr.val, tr.args...) + arg_grads = logpdf_grad(dist(tr), tr.val, tr.args...)[2:end] return arg_grads end
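
A minimal sketch of the compatibility pattern applied in the last commit above: slurping assignment (`first, rest... = tuple`) is unavailable on older Julia releases, so the tuple returned by `logpdf_grad` is split by indexing instead. The `grads` tuple below is a hypothetical stand-in for a real `logpdf_grad(dist, val, args...)` result, used only to illustrate the shape of the split.

    # stand-in for logpdf_grad(dist, val, args...): (output grad, arg grads...)
    grads = (0.5, -1.2, 0.3)
    output_grad = grads[1]      # gradient w.r.t. the sampled value
    arg_grads = grads[2:end]    # gradients w.r.t. the distribution arguments
    @assert output_grad == 0.5
    @assert arg_grads == (-1.2, 0.3)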