Skip to content

Commit b09f7c5

Browse files
committed
Change log-likelihood usage
1 parent 62ff132 commit b09f7c5

File tree

3 files changed

+62
-58
lines changed

3 files changed

+62
-58
lines changed

README.Rmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ y <- rnorm(500, 10, 2)
7171

7272
As with other estimation routines provided in R, we need to specify this as a
7373
function which takes a vector of parameters as its first argument and returns a
74-
single scalar value (the log-likelihood), as well as initial values for the
74+
single scalar value (the unnormalized target log density), as well as initial values for the
7575
parameters:
7676

7777
```{r}

README.md

Lines changed: 58 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ y <- rnorm(500, 10, 2)
6262

6363
As with other estimation routines provided in R, we need to specify this
6464
as a function which takes a vector of parameters as its first argument
65-
and returns a single scalar value (the log-likelihood), as well as
66-
initial values for the parameters:
65+
and returns a single scalar value (the unnormalized target log density),
66+
as well as initial values for the parameters:
6767

6868
``` r
6969
loglik_fun <- function(v, x) {
@@ -107,14 +107,14 @@ iterations
107107
``` r
108108
unlist(fit@timing)
109109
#> warmup sampling
110-
#> 0.527 0.490
110+
#> 0.720 0.707
111111
summary(fit)
112112
#> # A tibble: 3 × 10
113113
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
114114
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
115-
#> 1 lp__ -1.08e3 -1.08e3 1.03 0.749 -1.08e3 -1.08e3 1.00 507. 672.
116-
#> 2 pars[1] 1.01e1 1.01e1 0.0940 0.0948 9.97e0 1.03e1 1.00 895. 671.
117-
#> 3 pars[2] 2.11e0 2.10e0 0.0686 0.0670 2.00e0 2.22e0 1.00 860. 696.
115+
#> 1 lp__ -1.05e3 -1.05e3 0.973 0.788 -1.05e3 -1.05e3 1.01 521. 720.
116+
#> 2 pars[1] 9.96e0 9.97e0 0.0912 0.0911 9.81e0 1.01e1 1.00 943. 712.
117+
#> 3 pars[2] 1.97e0 1.96e0 0.0637 0.0674 1.87e0 2.08e0 1.00 878. 615.
118118
```
119119

120120
Estimation time can be improved further by providing a gradient
@@ -134,14 +134,14 @@ Which shows that the estimation time was dramatically improved, now
134134
``` r
135135
unlist(fit_grad@timing)
136136
#> warmup sampling
137-
#> 0.111 0.087
137+
#> 0.103 0.093
138138
summary(fit_grad)
139139
#> # A tibble: 3 × 10
140140
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
141141
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
142-
#> 1 lp__ -1.08e3 -1.08e3 1.02 0.741 -1.08e3 -1.08e3 1.00 572. 712.
143-
#> 2 pars[1] 1.01e1 1.01e1 0.0928 0.0943 9.97e0 1.03e1 1.00 950. 623.
144-
#> 3 pars[2] 2.10e0 2.10e0 0.0691 0.0696 1.99e0 2.22e0 1.00 725. 613.
142+
#> 1 lp__ -1.05e3 -1.05e3 0.952 0.763 -1.05e3 -1.05e3 1.01 500. 675.
143+
#> 2 pars[1] 9.97e0 9.97e0 0.0905 0.0954 9.82e0 1.01e1 1.000 830. 531.
144+
#> 3 pars[2] 1.96e0 1.96e0 0.0619 0.0616 1.87e0 2.07e0 1.00 1047. 640.
145145
```
146146

147147
### Optimization
@@ -158,11 +158,11 @@ opt_grad <- stan_optimize(loglik_fun, inits, additional_args = list(y),
158158

159159
``` r
160160
summary(opt_fd)
161-
#> lp__ pars[1] pars[2]
162-
#> 1 -1079.84 10.1221 2.09743
161+
#> lp__ pars[1] pars[2]
162+
#> 1 -1046.049 9.9691 1.96036
163163
summary(opt_grad)
164-
#> lp__ pars[1] pars[2]
165-
#> 1 -1079.84 10.1221 2.09743
164+
#> lp__ pars[1] pars[2]
165+
#> 1 -1046.049 9.9691 1.96036
166166
```
167167

168168
### Laplace Approximation
@@ -191,28 +191,28 @@ summary(lapl_num)
191191
#> # A tibble: 4 × 10
192192
#> variable mean median sd mad q5 q95 rhat ess_bulk
193193
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
194-
#> 1 log_p__ -1082. -1082. 2.39 2.18 -1087. -1080. 0.999 989.
195-
#> 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047.
196-
#> 3 pars[1] 10.0 10.0 0.0899 0.0867 9.85 10.1 1.00 933.
197-
#> 4 pars[2] 2.00 2.00 0.0626 0.0635 1.90 2.11 1.00 1051.
194+
#> 1 log_p__ -1477. -1475. 55.3 56.0 -1572. -1389. 1.00 986.
195+
#> 2 log_q__ -1.01 -0.695 1.01 0.743 -3.03 -0.0443 1.00 913.
196+
#> 3 pars[1] 10.0 10.00 0.335 0.343 9.47 10.5 0.999 831.
197+
#> 4 pars[2] 7.45 7.39 0.897 0.893 6.10 9.08 1.00 987.
198198
#> # ℹ 1 more variable: ess_tail <dbl>
199199
summary(lapl_opt)
200200
#> # A tibble: 4 × 10
201201
#> variable mean median sd mad q5 q95 rhat ess_bulk
202202
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
203-
#> 1 log_p__ -1080. -1080. 1.06 0.712 -1082. -1079. 0.999 1044.
204-
#> 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047.
205-
#> 3 pars[1] 10.1 10.1 0.0940 0.0897 9.96 10.3 1.00 932.
206-
#> 4 pars[2] 2.10 2.10 0.0688 0.0697 1.99 2.21 1.00 1051.
203+
#> 1 log_p__ -1458. -1457. 52.8 53.5 -1549. -1374. 1.00 986.
204+
#> 2 log_q__ -1.01 -0.695 1.01 0.743 -3.03 -0.0443 1.00 913.
205+
#> 3 pars[1] 9.97 9.97 0.321 0.329 9.46 10.5 0.999 830.
206+
#> 4 pars[2] 7.16 7.10 0.827 0.824 5.91 8.66 1.00 987.
207207
#> # ℹ 1 more variable: ess_tail <dbl>
208208
summary(lapl_est)
209209
#> # A tibble: 4 × 10
210210
#> variable mean median sd mad q5 q95 rhat ess_bulk
211211
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
212-
#> 1 log_p__ -1080. -1080. 1.06 0.712 -1082. -1079. 0.999 1044.
213-
#> 2 log_q__ -1.04 -0.692 1.04 0.716 -3.21 -0.0582 0.999 1047.
214-
#> 3 pars[1] 10.1 10.1 0.0940 0.0897 9.96 10.3 1.00 932.
215-
#> 4 pars[2] 2.10 2.10 0.0688 0.0697 1.99 2.21 1.00 1051.
212+
#> 1 log_p__ -1458. -1457. 52.8 53.5 -1549. -1374. 1.00 986.
213+
#> 2 log_q__ -1.01 -0.695 1.01 0.743 -3.03 -0.0443 1.00 913.
214+
#> 3 pars[1] 9.97 9.97 0.321 0.329 9.46 10.5 0.999 830.
215+
#> 4 pars[2] 7.16 7.10 0.827 0.824 5.91 8.66 1.00 987.
216216
#> # ℹ 1 more variable: ess_tail <dbl>
217217
```
218218

@@ -231,23 +231,23 @@ var_grad <- stan_variational(loglik_fun, inits, additional_args = list(y),
231231
``` r
232232
summary(var_fd)
233233
#> # A tibble: 5 × 10
234-
#> variable mean median sd mad q5 q95 rhat ess_bulk
235-
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
236-
#> 1 lp__ 0 0 0 0 0 0 NA NA
237-
#> 2 log_p__ -1081. -1080. 1.33 0.986 -1083. -1079. 0.999 997.
238-
#> 3 log_g__ -1.03 -0.714 1.03 0.731 -3.29 -0.0486 1.00 959.
239-
#> 4 pars[1] 10.2 10.2 0.0869 0.0898 10.1 10.4 1.00 1012.
240-
#> 5 pars[2] 2.09 2.09 0.0650 0.0639 1.99 2.20 1.00 850.
234+
#> variable mean median sd mad q5 q95 rhat ess_bulk
235+
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
236+
#> 1 lp__ 0 0 0 0 0 0 NA NA
237+
#> 2 log_p__ -1047. -1046. 1.25 0.975 -1049. -1045. 1.00 1017.
238+
#> 3 log_g__ -0.978 -0.660 0.966 0.678 -2.84 -0.0566 1.00 1054.
239+
#> 4 pars[1] 10.0 10.0 0.0847 0.0877 9.88 10.2 0.999 1025.
240+
#> 5 pars[2] 1.92 1.92 0.0528 0.0523 1.83 2.01 1.00 1047.
241241
#> # ℹ 1 more variable: ess_tail <dbl>
242242
summary(var_grad)
243243
#> # A tibble: 5 × 10
244-
#> variable mean median sd mad q5 q95 rhat ess_bulk
245-
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
246-
#> 1 lp__ 0 0 0 0 0 0 NA NA
247-
#> 2 log_p__ -1081. -1080. 1.33 0.986 -1083. -1079. 0.999 997.
248-
#> 3 log_g__ -1.03 -0.714 1.03 0.731 -3.29 -0.0486 1.00 959.
249-
#> 4 pars[1] 10.2 10.2 0.0869 0.0898 10.1 10.4 1.00 1012.
250-
#> 5 pars[2] 2.09 2.09 0.0650 0.0639 1.99 2.20 1.00 850.
244+
#> variable mean median sd mad q5 q95 rhat ess_bulk
245+
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
246+
#> 1 lp__ 0 0 0 0 0 0 NA NA
247+
#> 2 log_p__ -1047. -1046. 1.25 0.975 -1049. -1045. 1.00 1017.
248+
#> 3 log_g__ -0.978 -0.660 0.966 0.678 -2.84 -0.0566 1.00 1054.
249+
#> 4 pars[1] 10.0 10.0 0.0847 0.0877 9.88 10.2 0.999 1025.
250+
#> 5 pars[2] 1.92 1.92 0.0528 0.0523 1.83 2.01 1.00 1047.
251251
#> # ℹ 1 more variable: ess_tail <dbl>
252252
```
253253

@@ -265,19 +265,23 @@ path_grad <- stan_pathfinder(loglik_fun, inits, additional_args = list(y),
265265

266266
``` r
267267
summary(path_fd)
268-
#> # A tibble: 4 × 10
269-
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
270-
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
271-
#> 1 lp_appr… 2.96e0 3.27e0 0.998 0.716 1.03e0 3.90e0 1.00 949. 909.
272-
#> 2 lp__ -1.08e3 -1.08e3 1.04 0.726 -1.08e3 -1.08e3 1.00 946. 820.
273-
#> 3 pars[1] 1.01e1 1.01e1 0.0955 0.0920 9.96e0 1.03e1 0.999 1004. 800.
274-
#> 4 pars[2] 2.10e0 2.11e0 0.0668 0.0695 1.99e0 2.21e0 1.00 998. 907.
268+
#> # A tibble: 5 × 10
269+
#> variable mean median sd mad q5 q95 rhat ess_bulk
270+
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
271+
#> 1 lp_approx__ 3.04 3.45 1.19 0.704 0.609 4.07 1.00 652.
272+
#> 2 lp__ -1046. -1046. 1.09 0.661 -1049. -1045. 1.00 653.
273+
#> 3 path__ 2.51 3 1.10 1.48 1 4 2.65 1.20
274+
#> 4 pars[1] 9.97 9.96 0.0872 0.0835 9.82 10.1 1.000 803.
275+
#> 5 pars[2] 1.96 1.96 0.0633 0.0606 1.86 2.07 1.00 734.
276+
#> # ℹ 1 more variable: ess_tail <dbl>
275277
summary(path_grad)
276-
#> # A tibble: 4 × 10
277-
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
278-
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
279-
#> 1 lp_appr… 2.96e0 3.27e0 0.998 0.716 1.03e0 3.90e0 1.00 949. 909.
280-
#> 2 lp__ -1.08e3 -1.08e3 1.04 0.726 -1.08e3 -1.08e3 1.00 946. 820.
281-
#> 3 pars[1] 1.01e1 1.01e1 0.0955 0.0920 9.96e0 1.03e1 0.999 1004. 800.
282-
#> 4 pars[2] 2.10e0 2.11e0 0.0668 0.0695 1.99e0 2.21e0 1.00 998. 907.
278+
#> # A tibble: 5 × 10
279+
#> variable mean median sd mad q5 q95 rhat ess_bulk
280+
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
281+
#> 1 lp_approx__ 3.04 3.45 1.19 0.704 0.609 4.07 1.00 652.
282+
#> 2 lp__ -1046. -1046. 1.09 0.661 -1049. -1045. 1.00 653.
283+
#> 3 path__ 2.51 3 1.10 1.48 1 4 2.65 1.20
284+
#> 4 pars[1] 9.97 9.96 0.0872 0.0835 9.82 10.1 1.000 803.
285+
#> 5 pars[2] 1.96 1.96 0.0633 0.0606 1.86 2.07 1.00 734.
286+
#> # ℹ 1 more variable: ess_tail <dbl>
283287
```

vignettes/Getting-Started.Rmd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ sigma <- c(15, 10, 16, 11, 9, 11, 10, 18)
7171

7272
### Specifying the Function
7373

74-
To specify this as a function compatible with `StanEstimators`, we need to define a function that takes in a vector of parameters as the first argument and returns a single value (generally the joint log-likelihood):
74+
To specify this as a function compatible with `StanEstimators`, we need to define a function that takes in a vector of parameters as the first argument and returns a single value (generally the unnormalized target log density):
7575

7676
```{r}
7777
eight_schools_lpdf <- function(v, y, sigma) {
@@ -127,7 +127,7 @@ summary(fit)
127127

128128
## Model Checking and Comparison - Leave-One-Out Cross-Validation (LOO-CV)
129129

130-
`StanEstimators` also supports the use of the [loo](https://mc-stan.org/loo/articles/loo2-example.html) package for model checking and comparison. To use this, we need to specify a function which returns the pointwise log-likelihood for each observation in the data - as our original function returns the sum of all log-likelihoods.
130+
`StanEstimators` also supports the use of the [loo](https://mc-stan.org/loo/articles/loo2-example.html) package for model checking and comparison. To use this, we need to specify a function which returns the pointwise log density of the outcome (the log-likelihood contribution of each individual observation in the data), because our original function returns only a single value summed over all observations.
131131

132132
For our model, we can define this function as:
133133

@@ -141,7 +141,7 @@ eight_schools_pointwise <- function(v, y, sigma) {
141141
# https://mc-stan.org/docs/stan-users-guide/reparameterization.html
142142
theta <- mu + tau * eta
143143
144-
# Only the log-likelihood for the outcome variable
144+
# Only the log density for the outcome variable
145145
dnorm(y, mean = theta, sd = sigma, log = TRUE)
146146
}
147147
```

0 commit comments

Comments
 (0)