@@ -62,8 +62,8 @@ y <- rnorm(500, 10, 2)
6262
6363As with other estimation routines provided in R, we need to specify this
6464as a function which takes a vector of parameters as its first argument
65- and returns a single scalar value (the log-likelihood), as well as
66- initial values for the parameters:
65+ and returns a single scalar value (the unnormalized target log density),
66+ as well as initial values for the parameters:
6767
6868``` r
6969loglik_fun <- function (v , x ) {
@@ -107,14 +107,14 @@ iterations
107107``` r
108108unlist(fit @ timing )
109109# > warmup sampling
110- # > 0.527 0.490
110+ # > 0.720 0.707
111111summary(fit )
112112# > # A tibble: 3 × 10
113113# > variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
114114# > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
115- # > 1 lp__ -1.08e3 -1.08e3 1.03 0.749 -1.08e3 -1.08e3 1.00 507 . 672 .
116- # > 2 pars[1] 1.01e1 1.01e1 0.0940 0.0948 9.97e0 1.03e1 1.00 895 . 671 .
117- # > 3 pars[2] 2.11e0 2.10e0 0.0686 0.0670 2.00e0 2.22e0 1.00 860 . 696 .
115+ # > 1 lp__ -1.05e3 -1.05e3 0.973 0.788 -1.05e3 -1.05e3 1.01 521 . 720 .
116+ # > 2 pars[1] 9.96e0 9.97e0 0.0912 0.0911 9.81e0 1.01e1 1.00 943 . 712 .
117+ # > 3 pars[2] 1.97e0 1.96e0 0.0637 0.0674 1.87e0 2.08e0 1.00 878 . 615 .
118118```
119119
120120Estimation time can be improved further by providing a gradient
@@ -134,14 +134,14 @@ Which shows that the estimation time was dramatically improved, now
134134``` r
135135unlist(fit_grad @ timing )
136136# > warmup sampling
137- # > 0.111 0.087
137+ # > 0.103 0.093
138138summary(fit_grad )
139139# > # A tibble: 3 × 10
140140# > variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
141141# > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
142- # > 1 lp__ -1.08e3 -1.08e3 1.02 0.741 -1.08e3 -1.08e3 1.00 572 . 712 .
143- # > 2 pars[1] 1.01e1 1.01e1 0.0928 0.0943 9.97e0 1.03e1 1.00 950 . 623 .
144- # > 3 pars[2] 2.10e0 2.10e0 0.0691 0.0696 1.99e0 2.22e0 1.00 725 . 613 .
142+ # > 1 lp__ -1.05e3 -1.05e3 0.952 0.763 -1.05e3 -1.05e3 1.01 500 . 675 .
143+ # > 2 pars[1] 9.97e0 9.97e0 0.0905 0.0954 9.82e0 1.01e1 1.000 830 . 531 .
144+ # > 3 pars[2] 1.96e0 1.96e0 0.0619 0.0616 1.87e0 2.07e0 1.00 1047 . 640 .
145145```
146146
147147### Optimization
@@ -158,11 +158,11 @@ opt_grad <- stan_optimize(loglik_fun, inits, additional_args = list(y),
158158
159159``` r
160160summary(opt_fd )
161- # > lp__ pars[1] pars[2]
162- # > 1 -1079.84 10.1221 2.09743
161+ # > lp__ pars[1] pars[2]
162+ # > 1 -1046.049 9.9691 1.96036
163163summary(opt_grad )
164- # > lp__ pars[1] pars[2]
165- # > 1 -1079.84 10.1221 2.09743
164+ # > lp__ pars[1] pars[2]
165+ # > 1 -1046.049 9.9691 1.96036
166166```
167167
168168### Laplace Approximation
@@ -191,28 +191,28 @@ summary(lapl_num)
191191# > # A tibble: 4 × 10
192192# > variable mean median sd mad q5 q95 rhat ess_bulk
193193# > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
194- # > 1 log_p__ -1082 . -1082 . 2.39 2.18 -1087 . -1080 . 0.999 989 .
195- # > 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047 .
196- # > 3 pars[1] 10.0 10.0 0.0899 0.0867 9.85 10.1 1.00 933 .
197- # > 4 pars[2] 2.00 2.00 0.0626 0.0635 1.90 2.11 1.00 1051 .
194+ # > 1 log_p__ -1477 . -1475 . 55.3 56.0 -1572 . -1389 . 1.00 986 .
195+ # > 2 log_q__ -1.01 -0.695 1.01 0.743 -3.03 -0.0443 1.00 913 .
196+ # > 3 pars[1] 10.0 10.00 0.335 0.343 9.47 10.5 0.999 831 .
197+ # > 4 pars[2] 7.45 7.39 0.897 0.893 6.10 9.08 1.00 987 .
198198# > # ℹ 1 more variable: ess_tail <dbl>
199199summary(lapl_opt )
200200# > # A tibble: 4 × 10
201201# > variable mean median sd mad q5 q95 rhat ess_bulk
202202# > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
203- # > 1 log_p__ -1080 . -1080 . 1.06 0.712 -1082 . -1079 . 0.999 1044 .
204- # > 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047 .
205- # > 3 pars[1] 10.1 10.1 0.0940 0.0897 9.96 10.3 1.00 932 .
206- # > 4 pars[2] 2.10 2 .10 0.0688 0.0697 1.99 2.21 1.00 1051 .
203+ # > 1 log_p__ -1458 . -1457 . 52.8 53.5 -1549 . -1374 . 1.00 986 .
204+ # > 2 log_q__ -1.01 -0.695 1.01 0.743 -3.03 -0.0443 1.00 913 .
205+ # > 3 pars[1] 9.97 9.97 0.321 0.329 9.46 10.5 0.999 830 .
206+ # > 4 pars[2] 7.16 7 .10 0.827 0.824 5.91 8.66 1.00 987 .
207207# > # ℹ 1 more variable: ess_tail <dbl>
208208summary(lapl_est )
209209# > # A tibble: 4 × 10
210210# > variable mean median sd mad q5 q95 rhat ess_bulk
211211# > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
212- # > 1 log_p__ -1080 . -1080 . 1.06 0.712 -1082 . -1079 . 0.999 1044 .
213- # > 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047 .
214- # > 3 pars[1] 10.1 10.1 0.0940 0.0897 9.96 10.3 1.00 932 .
215- # > 4 pars[2] 2.10 2 .10 0.0688 0.0697 1.99 2.21 1.00 1051 .
212+ # > 1 log_p__ -1458 . -1457 . 52.8 53.5 -1549 . -1374 . 1.00 986 .
213+ # > 2 log_q__ -1.01 -0.695 1.01 0.743 -3.03 -0.0443 1.00 913 .
214+ # > 3 pars[1] 9.97 9.97 0.321 0.329 9.46 10.5 0.999 830 .
215+ # > 4 pars[2] 7.16 7 .10 0.827 0.824 5.91 8.66 1.00 987 .
216216# > # ℹ 1 more variable: ess_tail <dbl>
217217```
218218
@@ -231,23 +231,23 @@ var_grad <- stan_variational(loglik_fun, inits, additional_args = list(y),
231231``` r
232232summary(var_fd )
233233# > # A tibble: 5 × 10
234- # > variable mean median sd mad q5 q95 rhat ess_bulk
235- # > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
236- # > 1 lp__ 0 0 0 0 0 0 NA NA
237- # > 2 log_p__ -1081 . -1080 . 1.33 0.986 -1083 . -1079 . 0.999 997 .
238- # > 3 log_g__ -1.03 -0.714 1.03 0.731 -3.29 -0.0486 1.00 959 .
239- # > 4 pars[1] 10.2 10.2 0.0869 0.0898 10.1 10.4 1.00 1012 .
240- # > 5 pars[2] 2.09 2.09 0.0650 0.0639 1.99 2.20 1.00 850 .
234+ # > variable mean median sd mad q5 q95 rhat ess_bulk
235+ # > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
236+ # > 1 lp__ 0 0 0 0 0 0 NA NA
237+ # > 2 log_p__ -1047 . -1046 . 1.25 0.975 -1049 . -1045 . 1.00 1017 .
238+ # > 3 log_g__ -0.978 -0.660 0.966 0.678 -2.84 -0.0566 1.00 1054 .
239+ # > 4 pars[1] 10.0 10.0 0.0847 0.0877 9.88 10.2 0.999 1025 .
240+ # > 5 pars[2] 1.92 1.92 0.0528 0.0523 1.83 2.01 1.00 1047 .
241241# > # ℹ 1 more variable: ess_tail <dbl>
242242summary(var_grad )
243243# > # A tibble: 5 × 10
244- # > variable mean median sd mad q5 q95 rhat ess_bulk
245- # > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
246- # > 1 lp__ 0 0 0 0 0 0 NA NA
247- # > 2 log_p__ -1081 . -1080 . 1.33 0.986 -1083 . -1079 . 0.999 997 .
248- # > 3 log_g__ -1.03 -0.714 1.03 0.731 -3.29 -0.0486 1.00 959 .
249- # > 4 pars[1] 10.2 10.2 0.0869 0.0898 10.1 10.4 1.00 1012 .
250- # > 5 pars[2] 2.09 2.09 0.0650 0.0639 1.99 2.20 1.00 850 .
244+ # > variable mean median sd mad q5 q95 rhat ess_bulk
245+ # > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
246+ # > 1 lp__ 0 0 0 0 0 0 NA NA
247+ # > 2 log_p__ -1047 . -1046 . 1.25 0.975 -1049 . -1045 . 1.00 1017 .
248+ # > 3 log_g__ -0.978 -0.660 0.966 0.678 -2.84 -0.0566 1.00 1054 .
249+ # > 4 pars[1] 10.0 10.0 0.0847 0.0877 9.88 10.2 0.999 1025 .
250+ # > 5 pars[2] 1.92 1.92 0.0528 0.0523 1.83 2.01 1.00 1047 .
251251# > # ℹ 1 more variable: ess_tail <dbl>
252252```
253253
@@ -265,19 +265,23 @@ path_grad <- stan_pathfinder(loglik_fun, inits, additional_args = list(y),
265265
266266``` r
267267summary(path_fd )
268- # > # A tibble: 4 × 10
269- # > variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
270- # > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
271- # > 1 lp_appr… 2.96e0 3.27e0 0.998 0.716 1.03e0 3.90e0 1.00 949. 909.
272- # > 2 lp__ -1.08e3 -1.08e3 1.04 0.726 -1.08e3 -1.08e3 1.00 946. 820.
273- # > 3 pars[1] 1.01e1 1.01e1 0.0955 0.0920 9.96e0 1.03e1 0.999 1004. 800.
274- # > 4 pars[2] 2.10e0 2.11e0 0.0668 0.0695 1.99e0 2.21e0 1.00 998. 907.
268+ # > # A tibble: 5 × 10
269+ # > variable mean median sd mad q5 q95 rhat ess_bulk
270+ # > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
271+ # > 1 lp_approx__ 3.04 3.45 1.19 0.704 0.609 4.07 1.00 652.
272+ # > 2 lp__ -1046. -1046. 1.09 0.661 -1049. -1045. 1.00 653.
273+ # > 3 path__ 2.51 3 1.10 1.48 1 4 2.65 1.20
274+ # > 4 pars[1] 9.97 9.96 0.0872 0.0835 9.82 10.1 1.000 803.
275+ # > 5 pars[2] 1.96 1.96 0.0633 0.0606 1.86 2.07 1.00 734.
276+ # > # ℹ 1 more variable: ess_tail <dbl>
275277summary(path_grad )
276- # > # A tibble: 4 × 10
277- # > variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
278- # > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
279- # > 1 lp_appr… 2.96e0 3.27e0 0.998 0.716 1.03e0 3.90e0 1.00 949. 909.
280- # > 2 lp__ -1.08e3 -1.08e3 1.04 0.726 -1.08e3 -1.08e3 1.00 946. 820.
281- # > 3 pars[1] 1.01e1 1.01e1 0.0955 0.0920 9.96e0 1.03e1 0.999 1004. 800.
282- # > 4 pars[2] 2.10e0 2.11e0 0.0668 0.0695 1.99e0 2.21e0 1.00 998. 907.
278+ # > # A tibble: 5 × 10
279+ # > variable mean median sd mad q5 q95 rhat ess_bulk
280+ # > <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
281+ # > 1 lp_approx__ 3.04 3.45 1.19 0.704 0.609 4.07 1.00 652.
282+ # > 2 lp__ -1046. -1046. 1.09 0.661 -1049. -1045. 1.00 653.
283+ # > 3 path__ 2.51 3 1.10 1.48 1 4 2.65 1.20
284+ # > 4 pars[1] 9.97 9.96 0.0872 0.0835 9.82 10.1 1.000 803.
285+ # > 5 pars[2] 1.96 1.96 0.0633 0.0606 1.86 2.07 1.00 734.
286+ # > # ℹ 1 more variable: ess_tail <dbl>
283287```
0 commit comments