Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 79 additions & 31 deletions src/gen/inference/importance.cljc
Original file line number Diff line number Diff line change
@@ -1,39 +1,87 @@
(ns gen.inference.importance
(:require [clojure.math :as math]
(:require [gen.choicemap :as choicemap]
[gen.distribution.kixi :as dist]
[gen.generative-function :as gf]))

;; This implementation comes from `fastmath.core`, ported here for cljc
;; purposes.

(defn- logsumexp
"log(exp(x1)+...+exp(xn))."
^double [xs]
(loop [[^double x & rst] xs
r 0.0
alpha ##-Inf]
(if (<= x alpha)
(let [nr (+ r (math/exp (- x alpha)))]
(if-not (seq rst)
(+ (math/log nr) alpha)
(recur rst nr alpha)))
(let [nr (inc (* r (math/exp (- alpha x))))]
(if-not (seq rst)
(+ (math/log nr) x)
(recur rst nr (double x)))))))
(defn ^:nodoc logsumexp
"log(exp(x1)+...+exp(xn)). This version can handle infinities."
([x] x)
([x y]
(if (<= y x)
(+ x (Math/log (inc (Math/exp (- y x)))))
(+ y (Math/log (inc (Math/exp (- x y)))))))
([x y & more]
(reduce logsumexp (logsumexp x y) more)))

(defn resampling [gf args observations n-samples]
;; https://github.com/probcomp/Gen.jl/blob/master/src/inference/importance.jl#L77...L95
(let [result (gf/generate gf args observations)
model-trace (volatile! (:trace result))
log-total-weight (volatile! (:weight result))]
(dotimes [_ (dec n-samples)]
(let [candidate (gf/generate gf args observations)
candidate-model-trace (:trace candidate)
log-weight (:weight candidate)]
(vswap! log-total-weight #(logsumexp [log-weight %]))
(when (dist/bernoulli (math/exp (- log-weight @log-total-weight)))
(vreset! model-trace candidate-model-trace))))
(let [log-ml-estimate (- @log-total-weight (math/log n-samples))]
{:trace @model-trace
:weight log-ml-estimate})))
(defn- proposal-fn
([gen-fn args constraints]
(if (empty? constraints)
(fn []
(gf/propose gen-fn args))
(let [constraints (choicemap/choicemap constraints)]
(fn []
(choicemap/merge
constraints
(gf/propose gen-fn args)))))))

(defn sampling
([gf args {:keys [observations n-samples proposal proposal-args]
:or {proposal-args []
observations choicemap/EMPTY}}]
;; TODO do this so it returns a lazy-seq of samples.
)
([gf args observations n-samples]
(sampling gf args {:observations observations
:n-samples n-samples}))
([gf args observations proposal proposal-args n-samples]
(sampling gf args {:observations observations
:n-samples n-samples
:proposal proposal
:proposal-args proposal-args})))

(defn resampling
"https://github.com/probcomp/Gen.jl/blob/master/src/inference/importance.jl#L77...L95

Run sampling importance resampling, returning a single trace.

TODO do we actually want a sequence of traces encountered, so that we can
visualize the path?"
([gf args {:keys [observations n-samples proposal proposal-args]
:or {proposal-args []
observations choicemap/EMPTY
n-samples 10}}]
(let [constraint-fn (if proposal
(proposal-fn proposal proposal-args observations)
(fn [] {:choices observations
:weight 0.0}))
{proposal-choices :choices
proposal-weight :weight} (constraint-fn)
result (gf/generate gf args proposal-choices)]
(loop [i (dec n-samples)
model-trace (:trace result)
log-total-weight (- (:weight result) proposal-weight)]
(if (zero? i)
(let [log-ml-estimate (- log-total-weight (Math/log n-samples))]
{:trace model-trace
:weight log-ml-estimate})
(let [{proposal-choices :choices
proposal-weight :weight} (constraint-fn)
candidate (gf/generate gf args proposal-choices)
candidate-trace (:trace candidate)
log-weight (- (:weight candidate) proposal-weight)
new-weight (logsumexp log-total-weight log-weight)
new-trace (if (dist/bernoulli (Math/exp (- log-weight new-weight)))
candidate-trace
model-trace)]
(recur (unchecked-dec-int i) new-trace new-weight))))))
([gf args observations n-samples]
(resampling gf args {:observations observations
:n-samples n-samples}))
([gf args observations proposal proposal-args n-samples]
(resampling gf args {:observations observations
:n-samples n-samples
:proposal proposal
:proposal-args proposal-args})))