Skip to content

Commit 9d7635e

Browse files
authored
Merge pull request #11 from psg-mit/master
SSI improvements
2 parents 4c3d072 + dcdf2be commit 9d7635e

File tree

13 files changed

+1157
-157
lines changed

13 files changed

+1157
-157
lines changed

probzelus/inference/distribution.ml

Lines changed: 284 additions & 13 deletions
Large diffs are not rendered by default.

probzelus/inference/distribution.zli

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,24 @@ val bernoulli_mean : 'a -> 'a
6666
val bernoulli_variance : float -> float
6767
val bernoulli : float -> bool t
6868

69+
val binomial_draw : int -> float -> int
70+
val binomial_score : int -> float -> int -> float
71+
val binomial_mean : int -> float -> float
72+
val binomial_variance : int -> float -> float
73+
val binomial : int * float -> int t
74+
75+
val negative_binomial_draw : int -> float -> int
76+
val negative_binomial_score : int -> float -> int -> float
77+
val negative_binomial_mean : int -> float -> float
78+
val negative_binomial_variance : int -> float -> float
79+
val negative_binomial : int * float -> int t
80+
81+
val beta_binomial_draw : int -> float -> float -> int
82+
val beta_binomial_score : int -> float -> float -> int -> float
83+
val beta_binomial_mean : int -> float -> float -> float
84+
val beta_binomial_variance : int -> float -> float -> float
85+
val beta_binomial : int * float * float -> int t
86+
6987
val gaussian_draw : float -> float -> float
7088
val gaussian_score : float -> float -> float -> float
7189
val gaussian_mean : 'a -> 'b -> 'a
@@ -114,12 +132,24 @@ val exponential_mean : float -> float
114132
val exponential_variance : float -> float
115133
val exponential : float -> float t
116134

135+
val gamma_draw : float -> float -> float
136+
val gamma_score : float -> float -> float -> float
137+
val gamma_mean : float -> float -> float
138+
val gamma_variance : float -> float -> float
139+
val gamma : float * float -> float t
140+
117141
val poisson_draw : float -> int
118142
val poisson_score : float -> float -> float
119143
val poisson_mean : float -> float
120144
val poisson_variance : float -> float
121145
val poisson : float -> int t
122146

147+
val student_t_draw : float -> float -> float -> float
148+
val student_t_score : float -> float -> float -> float
149+
val student_t_mean : float -> float -> float -> float
150+
val student_t_variance : float -> float -> float -> float
151+
val student_t : float * float * float -> float t
152+
123153
val alias_method_unsafe : 'a array -> float array -> 'a t
124154
val alias_method_list : ('a * float) list -> 'a t
125155
val alias_method : 'a array -> float array -> 'a t

probzelus/inference/ds_naive_graph.ml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ let get_distr_kind : type a b.
210210
| DSnaive_Initialized (_, CBernBern _) -> KBernoulli
211211
| DSnaive_Marginalized (Dist_bernoulli _, _) -> KBernoulli
212212
| DSnaive_Marginalized (Dist_beta _, _) -> KBeta
213+
| DSnaive_Marginalized (Dist_binomial _, _) -> KOthers
214+
| DSnaive_Marginalized (Dist_beta_binomial _, _) -> KOthers
215+
| DSnaive_Marginalized (Dist_negative_binomial _, _) -> KOthers
213216
| DSnaive_Marginalized (( Dist_sampler _
214217
| Dist_support _), _) -> KOthers
215218
| DSnaive_Marginalized (Dist_sampler_float _, _) -> KOthers
@@ -220,7 +223,9 @@ let get_distr_kind : type a b.
220223
| DSnaive_Marginalized (Dist_uniform_int _, _) -> KOthers
221224
| DSnaive_Marginalized (Dist_uniform_float _, _) -> KOthers
222225
| DSnaive_Marginalized (Dist_exponential _, _) -> KOthers
226+
| DSnaive_Marginalized (Dist_gamma _, _) -> KOthers
223227
| DSnaive_Marginalized (Dist_poisson _, _) -> KOthers
228+
| DSnaive_Marginalized (Dist_student_t _, _) -> KOthers
224229
| DSnaive_Marginalized (Dist_lognormal _, _) -> KOthers
225230
| DSnaive_Marginalized (Dist_add _, _) -> KOthers
226231
| DSnaive_Marginalized (Dist_mult _, _) -> KOthers

probzelus/inference/ds_streaming_graph.ml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ module Make(Distribution: DISTRIBUTION) = struct
165165
| DSgraph_Initialized (_, CBernoulli) -> KBernoulli
166166
| DSgraph_Initialized (_, CBernBern _) -> KBernoulli
167167
| DSgraph_Marginalized (Dist_bernoulli _, _) -> KBernoulli
168+
| DSgraph_Marginalized (Dist_binomial _, _) -> KOthers
169+
| DSgraph_Marginalized (Dist_beta_binomial _, _) -> KOthers
170+
| DSgraph_Marginalized (Dist_negative_binomial _, _) -> KOthers
168171
| DSgraph_Marginalized (Dist_beta _, _) -> KBeta
169172
| DSgraph_Marginalized (( Dist_sampler _
170173
| Dist_support _), _) -> KOthers
@@ -176,7 +179,9 @@ module Make(Distribution: DISTRIBUTION) = struct
176179
| DSgraph_Marginalized (Dist_uniform_int _, _) -> KOthers
177180
| DSgraph_Marginalized (Dist_uniform_float _, _) -> KOthers
178181
| DSgraph_Marginalized (Dist_exponential _, _) -> KOthers
182+
| DSgraph_Marginalized (Dist_gamma _, _) -> KOthers
179183
| DSgraph_Marginalized (Dist_poisson _, _) -> KOthers
184+
| DSgraph_Marginalized (Dist_student_t _, _) -> KOthers
180185
| DSgraph_Marginalized (Dist_lognormal _, _) -> KOthers
181186
| DSgraph_Marginalized (Dist_add _, _) -> KOthers
182187
| DSgraph_Marginalized (Dist_mult _, _) -> KOthers

probzelus/inference/infer_semi_symbolic.ml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,20 @@ type 'a expr = 'a Semi_symbolic.expr
77
let const = Semi_symbolic.const
88
let add (a, b) = Semi_symbolic.add a b
99
let ( +~ ) = Semi_symbolic.add
10+
let subtract (a, b) =
11+
Semi_symbolic.add a (Semi_symbolic.mul (Semi_symbolic.const (-1.)) b)
12+
let ( -~ ) = (fun a b -> subtract (a, b))
1013
let mult (a, b) = Semi_symbolic.mul a b
1114
let ( *~ ) = Semi_symbolic.mul
15+
let div (a, b) = Semi_symbolic.div a b
16+
let ( /~ ) = Semi_symbolic.div
17+
let expp = Semi_symbolic.exp
1218
let pair (a, b) = Semi_symbolic.pair a b
1319
let array = Semi_symbolic.array
1420
let matrix = Semi_symbolic.matrix
1521
let lst = Semi_symbolic.lst
1622
let ite = Semi_symbolic.ite
23+
let lt (a, b) = Semi_symbolic.lt a b
1724

1825
let mat_add (a, b) = Semi_symbolic.mat_add a b
1926
let ( +@~) = Semi_symbolic.mat_add
@@ -36,6 +43,15 @@ let gaussian (mu, var) = Semi_symbolic.gaussian mu (Semi_symbolic.const var)
3643
let beta (a, b) =
3744
Semi_symbolic.beta (Semi_symbolic.const a) (Semi_symbolic.const b)
3845
let bernoulli p = Semi_symbolic.bernoulli p
46+
let binomial (n, p) = Semi_symbolic.binomial (Semi_symbolic.const n) p
47+
let beta_binomial (n, a, b) =
48+
Semi_symbolic.beta_binomial (Semi_symbolic.const n) a b
49+
let negative_binomial (n, p) = Semi_symbolic.negative_binomial (Semi_symbolic.const n) p
50+
let exponential lambda = Semi_symbolic.exponential lambda
51+
let gamma (a, b) = Semi_symbolic.gamma a b
52+
let poisson lambda = Semi_symbolic.poisson lambda
53+
let student_t (mu, tau2, nu) = Semi_symbolic.student_t mu tau2 nu
54+
let uniform_int (a, b) = Semi_symbolic.categorical ~lower:a ~upper:b (fun _ -> 1./.(float_of_int (b-a+1)))
3955
let mv_gaussian (mu, var) = Semi_symbolic.mv_gaussian mu (Semi_symbolic.const var)
4056
let mv_gaussian_curried var mu = mv_gaussian (mu, var)
4157

@@ -85,7 +101,10 @@ module Convert_fn_distr : Semi_symbolic.Conversion_fn with type 'a t = 'a Types.
85101
let const v = Dist_support [v, 1.]
86102
let add d1 d2 = Dist_add(d1, d2)
87103
let mul d1 d2 = Dist_mult(d1, d2)
104+
let div _ _ = assert false
105+
let exp _ = assert false
88106
let eq _ _ = assert false (* TODO: what to do here? *)
107+
let lt _ _ = assert false (* TODO: what to do here? *)
89108
let pair d1 d2 = Dist_pair(d1, d2)
90109
let array d = Dist_array d
91110
let lst l = Dist_list l
@@ -103,8 +122,16 @@ module Convert_fn_distr : Semi_symbolic.Conversion_fn with type 'a t = 'a Types.
103122
let gaussian mu var = Dist_gaussian (mu, var)
104123
let beta a b = Dist_beta(a, b)
105124
let bernoulli p = Dist_bernoulli p
125+
let binomial n p = Dist_binomial (n, p)
126+
let beta_binomial n a b = Dist_beta_binomial (n, a, b)
127+
let negative_binomial n p = Dist_negative_binomial (n, p)
128+
let exponential lambda = Dist_exponential lambda
129+
let gamma a b = Dist_gamma (a, b)
130+
let poisson lambda = Dist_poisson lambda
131+
let student_t mu tau2 nu = Dist_student_t (mu, tau2, nu)
106132
let delta x = Distribution.dirac x
107133
let mv_gaussian mu var = Dist_mv_gaussian (mu, var, None)
134+
let mixture l = Dist_mixture l
108135
let sampler draw score = Dist_sampler (draw, score)
109136
let categorical ~lower ~upper _ =
110137
ignore (lower, upper);

probzelus/inference/infer_semi_symbolic.zli

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,21 @@ type 'a expr
2222
val const : 'a -> 'a expr
2323
val add : float expr * float expr -> float expr
2424
val ( +~ ) : float expr -> float expr -> float expr
25+
val subtract : float expr * float expr -> float expr
26+
val ( -~ ) : float expr -> float expr -> float expr
2527
val mult : float expr * float expr -> float expr
2628
val ( *~ ) : float expr -> float expr -> float expr
29+
val div : float expr * float expr -> float expr
30+
val ( /~ ) : float expr -> float expr -> float expr
31+
val expp : float expr -> float expr
2732
(* val app : ('a -> 'b) expr * 'a expr -> 'b expr
2833
val ( @@~ ) : ('a -> 'b) expr -> 'a expr -> 'b expr *)
2934
val pair : 'a expr * 'b expr -> ('a * 'b) expr
3035
val array : 'a expr array -> 'a array expr
3136
val lst : 'a expr list -> 'a list expr
3237
val ite : bool expr -> 'a expr -> 'a expr -> 'a expr
3338
val matrix : 'a expr array array -> 'a array array expr
39+
val lt : 'a expr * 'a expr -> bool expr
3440

3541
val mat_add : Mat.mat expr * Mat.mat expr -> Mat.mat expr
3642
val ( +@~ ) : Mat.mat expr -> Mat.mat expr -> Mat.mat expr
@@ -48,6 +54,12 @@ val of_distribution : 'a Distribution.t -> 'a ds_distribution
4854
val gaussian : float expr * float -> float ds_distribution
4955
val beta : float * float -> float ds_distribution
5056
val bernoulli : float expr -> bool ds_distribution
57+
val binomial : int * float expr -> int ds_distribution
58+
val beta_binomial : int * float expr * float expr -> int ds_distribution
59+
val negative_binomial : int * float expr -> int ds_distribution
60+
val exponential : float expr -> float ds_distribution
61+
val gamma : float expr * float expr -> float ds_distribution
62+
val poisson : float expr -> int ds_distribution
5163
val mv_gaussian : Mat.mat expr * Mat.mat -> Mat.mat ds_distribution
5264
val mv_gaussian_curried : Mat.mat -> Mat.mat expr -> Mat.mat ds_distribution
5365

probzelus/inference/types.ml

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,15 @@ type _ distr =
4747
| Dist_lognormal : float * float -> float distr
4848
| Dist_beta : float * float -> float distr
4949
| Dist_bernoulli : float -> bool distr
50+
| Dist_binomial : int * float -> int distr
51+
| Dist_beta_binomial : int * float * float -> int distr
52+
| Dist_negative_binomial : int * float -> int distr
5053
| Dist_uniform_int : int * int -> int distr
5154
| Dist_uniform_float : float * float -> float distr
5255
| Dist_exponential : float -> float distr
56+
| Dist_gamma : float * float -> float distr
5357
| Dist_poisson : float -> int distr
58+
| Dist_student_t: float * float * float -> float distr
5459
| Dist_add : float distr * float distr -> float distr
5560
| Dist_mult : float distr * float distr -> float distr
5661
| Dist_app : ('a -> 'b) distr * 'a distr -> 'b distr
@@ -160,14 +165,29 @@ module type DISTRIBUTION = sig
160165
val print_t : ('a -> string) -> 'a t -> unit
161166

162167
val sampler : (unit -> 'a) * ('a -> float) -> 'a t
163-
val gamma : float -> float
168+
val gamma_f : float -> float
164169
val log_gamma : float -> float
165170
val dirac : 'a -> 'a t
166171
val bernoulli_draw : float -> bool
167172
val bernoulli_score : float -> bool -> float
168173
val bernoulli_mean : 'a -> 'a
169174
val bernoulli_variance : float -> float
170175
val bernoulli : float -> bool t
176+
val binomial_draw : int -> float -> int
177+
val binomial_score : int -> float -> int -> float
178+
val binomial_mean : int -> float -> float
179+
val binomial_variance : int -> float -> float
180+
val binomial : int * float -> int t
181+
val negative_binomial_draw : int -> float -> int
182+
val negative_binomial_score : int -> float -> int -> float
183+
val negative_binomial_mean : int -> float -> float
184+
val negative_binomial_variance : int -> float -> float
185+
val negative_binomial : int * float -> int t
186+
val beta_binomial_draw : int -> float -> float -> int
187+
val beta_binomial_score : int -> float -> float -> int -> float
188+
val beta_binomial_mean : int -> float -> float -> float
189+
val beta_binomial_variance : int -> float -> float -> float
190+
val beta_binomial : int * float * float -> int t
171191
val gaussian_draw : float -> float -> float
172192
val gaussian_score : float -> float -> float -> float
173193
val gaussian_mean : 'a -> 'b -> 'a
@@ -213,11 +233,21 @@ module type DISTRIBUTION = sig
213233
val exponential_mean : float -> float
214234
val exponential_variance : float -> float
215235
val exponential : float -> float t
236+
val gamma_draw : float -> float -> float
237+
val gamma_score : float -> float -> float -> float
238+
val gamma_mean : float -> float -> float
239+
val gamma_variance : float -> float -> float
240+
val gamma : float * float -> float t
216241
val poisson_draw : float -> int
217242
val poisson_score : float -> int -> float
218243
val poisson_mean : 'a -> 'a
219244
val poisson_variance : 'a -> 'a
220245
val poisson : float -> int t
246+
val student_t_draw : float -> float -> float -> float
247+
val student_t_score : float -> float -> float -> float -> float
248+
val student_t_mean : float -> float -> float -> float
249+
val student_t_variance : float -> float -> float -> float
250+
val student_t : float * float * float -> float t
221251
val alias_method_unsafe : 'a array -> float array -> 'a t
222252
val alias_method_list : ('a * float) list -> 'a t
223253
val alias_method : 'a array -> float array -> 'a t

0 commit comments

Comments
 (0)