Skip to content

Commit f48ab37

Browse files
committed
feat: clean up importance sampling code
1 parent 56551c1 commit f48ab37

File tree

1 file changed

+79
-31
lines changed

1 file changed

+79
-31
lines changed

src/gen/inference/importance.cljc

Lines changed: 79 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,87 @@
11
(ns gen.inference.importance
2-
(:require [clojure.math :as math]
2+
(:require [gen.choicemap :as choicemap]
33
[gen.distribution.kixi :as dist]
44
[gen.generative-function :as gf]))
55

66
;; This implementation comes from `fastmath.core`, ported here for cljc
77
;; purposes.
88

9-
(defn- logsumexp
10-
"log(exp(x1)+...+exp(xn))."
11-
^double [xs]
12-
(loop [[^double x & rst] xs
13-
r 0.0
14-
alpha ##-Inf]
15-
(if (<= x alpha)
16-
(let [nr (+ r (math/exp (- x alpha)))]
17-
(if-not (seq rst)
18-
(+ (math/log nr) alpha)
19-
(recur rst nr alpha)))
20-
(let [nr (inc (* r (math/exp (- alpha x))))]
21-
(if-not (seq rst)
22-
(+ (math/log nr) x)
23-
(recur rst nr (double x)))))))
9+
(defn ^:nodoc logsumexp
10+
"log(exp(x1)+...+exp(xn)). This version can handle infinities."
11+
([x] x)
12+
([x y]
13+
(if (<= y x)
14+
(+ x (Math/log (inc (Math/exp (- y x)))))
15+
(+ y (Math/log (inc (Math/exp (- x y)))))))
16+
([x y & more]
17+
(reduce logsumexp (logsumexp x y) more)))
2418

25-
(defn resampling [gf args observations n-samples]
26-
;; https://github.com/probcomp/Gen.jl/blob/master/src/inference/importance.jl#L77...L95
27-
(let [result (gf/generate gf args observations)
28-
model-trace (volatile! (:trace result))
29-
log-total-weight (volatile! (:weight result))]
30-
(dotimes [_ (dec n-samples)]
31-
(let [candidate (gf/generate gf args observations)
32-
candidate-model-trace (:trace candidate)
33-
log-weight (:weight candidate)]
34-
(vswap! log-total-weight #(logsumexp [log-weight %]))
35-
(when (dist/bernoulli (math/exp (- log-weight @log-total-weight)))
36-
(vreset! model-trace candidate-model-trace))))
37-
(let [log-ml-estimate (- @log-total-weight (math/log n-samples))]
38-
{:trace @model-trace
39-
:weight log-ml-estimate})))
19+
(defn- proposal-fn
20+
([gen-fn args constraints]
21+
(if (empty? constraints)
22+
(fn []
23+
(gf/propose gen-fn args))
24+
(let [constraints (choicemap/choicemap constraints)]
25+
(fn []
26+
(choicemap/merge
27+
constraints
28+
(gf/propose gen-fn args)))))))
29+
30+
(defn sampling
31+
([gf args {:keys [observations n-samples proposal proposal-args]
32+
:or {proposal-args []
33+
observations choicemap/EMPTY}}]
34+
;; TODO do this so it returns a lazy-seq of samples.
35+
)
36+
([gf args observations n-samples]
37+
(sampling gf args {:observations observations
38+
:n-samples n-samples}))
39+
([gf args observations proposal proposal-args n-samples]
40+
(sampling gf args {:observations observations
41+
:n-samples n-samples
42+
:proposal proposal
43+
:proposal-args proposal-args})))
44+
45+
(defn resampling
46+
"https://github.com/probcomp/Gen.jl/blob/master/src/inference/importance.jl#L77...L95
47+
48+
Run sampling importance resampling, returning a single trace.
49+
50+
TODO do we actually want a sequence of traces encountered, so that we can
51+
visualize the path?"
52+
([gf args {:keys [observations n-samples proposal proposal-args]
53+
:or {proposal-args []
54+
observations choicemap/EMPTY
55+
n-samples 10}}]
56+
(let [constraint-fn (if proposal
57+
(proposal-fn proposal proposal-args observations)
58+
(fn [] {:choices observations
59+
:weight 0.0}))
60+
{proposal-choices :choices
61+
proposal-weight :weight} (constraint-fn)
62+
result (gf/generate gf args proposal-choices)]
63+
(loop [i (dec n-samples)
64+
model-trace (:trace result)
65+
log-total-weight (- (:weight result) proposal-weight)]
66+
(if (zero? i)
67+
(let [log-ml-estimate (- log-total-weight (Math/log n-samples))]
68+
{:trace model-trace
69+
:weight log-ml-estimate})
70+
(let [{proposal-choices :choices
71+
proposal-weight :weight} (constraint-fn)
72+
candidate (gf/generate gf args proposal-choices)
73+
candidate-trace (:trace candidate)
74+
log-weight (- (:weight candidate) proposal-weight)
75+
new-weight (logsumexp log-total-weight log-weight)
76+
new-trace (if (dist/bernoulli (Math/exp (- log-weight new-weight)))
77+
candidate-trace
78+
model-trace)]
79+
(recur (unchecked-dec-int i) new-trace new-weight))))))
80+
([gf args observations n-samples]
81+
(resampling gf args {:observations observations
82+
:n-samples n-samples}))
83+
([gf args observations proposal proposal-args n-samples]
84+
(resampling gf args {:observations observations
85+
:n-samples n-samples
86+
:proposal proposal
87+
:proposal-args proposal-args})))

0 commit comments

Comments
 (0)