Skip to content

Commit ab64bd9

Browse files
committed
feat: clean up importance sampling code
1 parent 56551c1 commit ab64bd9

File tree

1 file changed

+29
-32
lines changed

1 file changed

+29
-32
lines changed

src/gen/inference/importance.cljc

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,36 @@
11
(ns gen.inference.importance
2-
(:require [clojure.math :as math]
3-
[gen.distribution.kixi :as dist]
2+
(:require [gen.distribution.kixi :as dist]
43
[gen.generative-function :as gf]))
54

65
;; This implementation comes from `fastmath.core`, ported here for cljc
76
;; purposes.
87

9-
(defn- logsumexp
10-
"log(exp(x1)+...+exp(xn))."
11-
^double [xs]
12-
(loop [[^double x & rst] xs
13-
r 0.0
14-
alpha ##-Inf]
15-
(if (<= x alpha)
16-
(let [nr (+ r (math/exp (- x alpha)))]
17-
(if-not (seq rst)
18-
(+ (math/log nr) alpha)
19-
(recur rst nr alpha)))
20-
(let [nr (inc (* r (math/exp (- alpha x))))]
21-
(if-not (seq rst)
22-
(+ (math/log nr) x)
23-
(recur rst nr (double x)))))))
8+
(defn ^:nodoc logsumexp
9+
"log(exp(x1)+...+exp(xn)). This version can handle infinities."
10+
([x] x)
11+
([x y]
12+
(if (<= y x)
13+
(+ x (Math/log (inc (Math/exp (- y x)))))
14+
(+ y (Math/log (inc (Math/exp (- x y)))))))
15+
([x y & more]
16+
(reduce logsumexp (logsumexp x y) more)))
2417

25-
(defn resampling [gf args observations n-samples]
26-
;; https://github.com/probcomp/Gen.jl/blob/master/src/inference/importance.jl#L77...L95
27-
(let [result (gf/generate gf args observations)
28-
model-trace (volatile! (:trace result))
29-
log-total-weight (volatile! (:weight result))]
30-
(dotimes [_ (dec n-samples)]
31-
(let [candidate (gf/generate gf args observations)
32-
candidate-model-trace (:trace candidate)
33-
log-weight (:weight candidate)]
34-
(vswap! log-total-weight #(logsumexp [log-weight %]))
35-
(when (dist/bernoulli (math/exp (- log-weight @log-total-weight)))
36-
(vreset! model-trace candidate-model-trace))))
37-
(let [log-ml-estimate (- @log-total-weight (math/log n-samples))]
38-
{:trace @model-trace
39-
:weight log-ml-estimate})))
18+
(defn resampling
19+
"https://github.com/probcomp/Gen.jl/blob/master/src/inference/importance.jl#L77...L95"
20+
[gf args observations n-samples]
21+
(let [result (gf/generate gf args observations)]
22+
(loop [i (dec n-samples)
23+
model-trace (:trace result)
24+
log-total-weight (:weight result)]
25+
(if (zero? i)
26+
(let [log-ml-estimate (- log-total-weight (Math/log n-samples))]
27+
{:trace model-trace
28+
:weight log-ml-estimate})
29+
(let [candidate (gf/generate gf args observations)
30+
candidate-trace (:trace candidate)
31+
log-weight (:weight candidate)
32+
new-weight (logsumexp log-total-weight log-weight)
33+
new-trace (if (dist/bernoulli (Math/exp (- log-weight new-weight)))
34+
candidate-trace
35+
model-trace)]
36+
(recur (unchecked-dec-int i) new-trace new-weight))))))

0 commit comments

Comments
 (0)