Skip to content

Commit febf288

Browse files
committed
feat: add students t-distribution
1 parent b6d5a7e commit febf288

File tree

8 files changed

+115
-15
lines changed

8 files changed

+115
-15
lines changed

src/gen/distribution/commons_math.clj

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
EnumeratedIntegerDistribution
1212
GammaDistribution
1313
NormalDistribution
14+
TDistribution
1415
UniformIntegerDistribution
1516
UniformRealDistribution)
1617
(org.apache.commons.math3.random
@@ -36,6 +37,19 @@
3637
(sample [^AbstractIntegerDistribution obj]
3738
(.sample obj)))
3839

40+
;; Small wrapper around a kixi T-distribution to allow for location and scale
41+
;; parameters.
42+
43+
(defrecord LocationScaleT [^TDistribution t-dist location scale]
44+
d/LogPDF
45+
(logpdf [_ v]
46+
(- (.logDensity t-dist (/ (- v location) scale))
47+
(Math/log scale)))
48+
49+
d/Sample
50+
(sample [_]
51+
(+ location (* scale (.sample t-dist)))))
52+
3953
;; ## Primitive probability distributions
4054

4155
(defn ^:no-doc rng ^RandomGenerator []
@@ -62,6 +76,13 @@
6276
(defn gamma-distribution [^double shape ^double scale]
6377
(GammaDistribution. (rng) shape scale))
6478

79+
(defn student-t-distribution
80+
([^double nu]
81+
(TDistribution. (rng) nu))
82+
([nu location scale]
83+
(->LocationScaleT
84+
(student-t-distribution nu) location scale)))
85+
6586
(defn normal-distribution
6687
([] (normal-distribution 0.0 1.0))
6788
([^double mean ^double sd]
@@ -92,6 +113,9 @@
92113
(def gamma
93114
(d/->GenerativeFn gamma-distribution))
94115

116+
(def student-t
117+
(d/->GenerativeFn student-t-distribution))
118+
95119
(def normal
96120
(d/->GenerativeFn normal-distribution))
97121

src/gen/distribution/kixi.cljc

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#?(:clj
66
(:import (kixi.stats.distribution Bernoulli Cauchy
77
Exponential Beta
8-
Gamma Normal Uniform))))
8+
Gamma Normal Uniform T))))
99

1010
;; ## Kixi.stats protocol implementations
1111
;;
@@ -80,6 +80,26 @@
8080
(.-sd this)
8181
v)))
8282

83+
(extend-type #?(:clj T :cljs k/T)
84+
d/Sample
85+
(sample [this] (k/draw this))
86+
87+
d/LogPDF
88+
(logpdf [this v]
89+
(ll/student-t (.-dof this) 0 1 v)))
90+
91+
;; Small wrapper around a kixi T-distribution to allow for location and scale
92+
;; parameters.
93+
94+
(defrecord LocationScaleT [^T t-dist location scale]
95+
d/Sample
96+
(sample [_]
97+
(+ location (* scale (k/draw t-dist))))
98+
99+
d/LogPDF
100+
(logpdf [_ v]
101+
(ll/student-t (.-dof t-dist) location scale v)))
102+
83103
;; ## Primitive probability distributions
84104

85105
(defn bernoulli-distribution
@@ -110,6 +130,14 @@
110130
(defn gamma-distribution [shape scale]
111131
(k/gamma {:shape shape :scale scale}))
112132

133+
(defn student-t-distribution
134+
([nu]
135+
(k/t {:v nu}))
136+
([nu location scale]
137+
(->LocationScaleT (student-t-distribution nu)
138+
location
139+
scale)))
140+
113141
;; ## Primitive generative functions
114142

115143
(def bernoulli
@@ -132,3 +160,6 @@
132160

133161
(def gamma
134162
(d/->GenerativeFn gamma-distribution))
163+
164+
(def student-t
165+
(d/->GenerativeFn student-t-distribution))

src/gen/distribution/math/log_likelihood.cljc

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,6 @@
167167
(Math/log variance)
168168
(/ v-mu-sq variance)))))
169169

170-
171170
(defn uniform
172171
"Returns the log-likelihood of the continuous [uniform
173172
distribution](https://en.wikipedia.org/wiki/Continuous_uniform_distribution)
@@ -177,3 +176,26 @@
177176
(if (<= a v b)
178177
(- (Math/log (- b a)))
179178
##-Inf))
179+
180+
(defn student-t
181+
"Returns the log-likelihood of the [non-standardized Student's
182+
t-distribution](https://en.wikipedia.org/wiki/Student's_t-distribution#Location-scale_transformation)
183+
parametrized by `location`, `scale` and degrees-of-freedom `nu` at the value
184+
`v`.
185+
186+
This distribution is also known as the location-scale t-distribution.
187+
188+
The implementation follows the algorithm described on the
189+
distribution's [Wikipedia
190+
page](https://en.wikipedia.org/wiki/Student's_t-distribution#Location-scale_transformation)."
191+
([nu v] (student-t nu 0 1 v))
192+
([nu location scale v]
193+
(let [inc-nu (inc nu)
194+
half-inc-nu (* 0.5 inc-nu)
195+
normalized (/ (- v location) scale)
196+
norm**2 (* normalized normalized)]
197+
(- (log-gamma-fn half-inc-nu)
198+
(log-gamma-fn (* 0.5 nu))
199+
(Math/log scale)
200+
(* 0.5 (+ log-pi (Math/log nu)))
201+
(* half-inc-nu (Math/log (inc (/ norm**2 nu))))))))

test/gen/distribution/commons_math_test.clj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@
1818

1919
(deftest gamma-tests
2020
(dt/gamma-tests commons/gamma-distribution))
21+
22+
(deftest student-t-tests
23+
(dt/student-t-tests commons/student-t-distribution))

test/gen/distribution/kixi_test.cljc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,6 @@
2424

2525
(deftest gamma-tests
2626
(dt/gamma-tests kixi/gamma-distribution))
27+
28+
(deftest student-t-tests
29+
(dt/student-t-tests kixi/student-t-distribution))

test/gen/distribution/math/log_likelihood_test.cljc

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,9 @@
44
[gen.distribution.math.log-likelihood :as ll]
55
[gen.distribution :as distribution]
66
[gen.distribution-test :as dt]
7-
[gen.test-check-util :refer [gen-double]]
7+
[gen.test-check-util :refer [gen-double within]]
88
[same.core :refer [ish? with-comparator]]))
99

10-
(defn within
11-
"Returns a function that tests whether two values are within `eps` of each
12-
other."
13-
[^double eps]
14-
(fn [^double x ^double y]
15-
(< (Math/abs (- x y)) eps)))
16-
1710
(defn factorial
1811
"Factorial implementation for testing."
1912
[n]
@@ -79,3 +72,6 @@
7972

8073
(deftest uniform-tests
8174
(dt/uniform-tests (->logpdf ll/uniform)))
75+
76+
(deftest student-t-tests
77+
(dt/student-t-tests (->logpdf ll/student-t)))

test/gen/distribution_test.cljc

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,27 @@
77
[gen.dynamic.choice-map :as choice-map]
88
[gen.generative-function :as gf]
99
[gen.trace :as trace]
10-
[gen.test-check-util :refer [gen-double]]
11-
[same.core :refer [ish? zeroish?]]))
10+
[gen.test-check-util :refer [gen-double within]]
11+
[same.core :refer [ish? zeroish? with-comparator]]))
1212

1313
(defn gamma-tests [->gamma]
1414
(testing "spot checks"
1515
(is (= -6.391804444241573 (dist/logpdf (->gamma 0.001 1) 0.4)))
1616
(is (= -393.0922447210179 (dist/logpdf (->gamma 1 0.001) 0.4)))))
1717

18+
(defn student-t-tests [->student-t]
19+
(testing "spot checks"
20+
(with-comparator (within 1e-12)
21+
(is (ish? -1.7347417805005154 (dist/logpdf (->student-t 2 2.1 2) 2)))
22+
(is (ish? -2.795309741614719 (dist/logpdf (->student-t 1 0.8 4) 3)))))
23+
24+
(checking "Student's T matches generalized logpdf"
25+
[v (gen-double -10 10)
26+
nu (gen/fmap inc gen/nat)]
27+
(is (= (dist/logpdf (->student-t nu 0 1) v)
28+
(dist/logpdf (->student-t nu) v))
29+
"these two paths should produce the same results")))
30+
1831
(defn beta-tests [->beta]
1932
(testing "spot checks"
2033
(is (= -5.992380837839856 (dist/logpdf (->beta 0.001 1) 0.4)))
@@ -161,9 +174,10 @@
161174
(dist/logpdf (->normal 0.0 sigma) (- v)))
162175
"Normal is symmetric about the mean")
163176

164-
(is (ish? (dist/logpdf (->normal mu sigma) v)
165-
(dist/logpdf (->normal (+ mu shift) sigma) (+ v shift)))
166-
"shifting by the mean is a symmetry"))
177+
(with-comparator (within 1e-12)
178+
(is (ish? (dist/logpdf (->normal mu sigma) v)
179+
(dist/logpdf (->normal (+ mu shift) sigma) (+ v shift)))
180+
"shifting by the mean is a symmetry")))
167181

168182
(testing "spot checks"
169183
(is (= -1.0439385332046727 (dist/logpdf (->normal 0 1) 0.5)))

test/gen/test_check_util.cljc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
"Utilities for using Gen distributions and types with test.check."
33
(:require [clojure.test.check.generators :as gen]))
44

5+
(defn within
6+
"Returns a function that tests whether two values are within `eps` of each
7+
other."
8+
[^double eps]
9+
(fn [^double x ^double y]
10+
(< (Math/abs (- x y)) eps)))
11+
512
(defn gen-double
613
"Returns a generator that produces numerical doubles between `min` and
714
`max` (inclusive)."

0 commit comments

Comments
 (0)