Skip to content

Commit ef72bd3

Browse files
Updates
1 parent f65baec commit ef72bd3

File tree

5 files changed

+188
-26
lines changed

5 files changed

+188
-26
lines changed

doc/src/fstats_distributions.f90

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ pure function mvnd_get_means(this) result(rst)
896896
end function
897897

898898
! ------------------------------------------------------------------------------
899-
pure function mvnd_get_covariance(this) result(rst)
899+
function mvnd_get_covariance(this) result(rst)
900900
!! Gets the covariance matrix of the distribution.
901901
class(multivariate_normal_distribution), intent(in) :: this
902902
!! The multivariate_normal_distribution object.
@@ -906,8 +906,9 @@ pure function mvnd_get_covariance(this) result(rst)
906906
! Process
907907
integer(int32) :: n
908908
if (allocated(this%m_cov)) then
909-
n = size(this%m_cov)
910-
allocate(rst(n, n), source = this%m_cov)
909+
n = size(this%m_cov, 1)
910+
print "N = ", n
911+
rst = this%m_cov
911912
else
912913
allocate(rst(0, 0))
913914
end if

examples/mcmc_regression_example_2.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ program example
8282
solver%fcn => fit_fcn
8383

8484
! Define upper and lower limits for each parameter (optional)
85-
solver%upper_limits = [1.0d1, 0.0d0, 1.0d2]
86-
solver%lower_limits = [0.1d0, -1.0d0, 1.0d1]
85+
solver%upper_limits = [1.0d1, 1.0d1, 1.0d2]
86+
solver%lower_limits = [0.1d0, -1.0d1, 1.0d1]
8787

8888
! Define an initial guess
8989
xi = [1.0d0, -0.5d0, 2.0d1]

src/fstats_distributions.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,7 @@ pure function mvnd_get_covariance(this) result(rst)
906906
! Process
907907
integer(int32) :: n
908908
if (allocated(this%m_cov)) then
909-
n = size(this%m_cov)
909+
n = size(this%m_cov, 1)
910910
allocate(rst(n, n), source = this%m_cov)
911911
else
912912
allocate(rst(0, 0))

src/fstats_mcmc.f90

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ module fstats_mcmc
3434
logical, private :: m_propDistInitialized = .false.
3535
!! Set to true if the proposal distribution object has been
3636
!! initialized; else, false.
37+
integer(int32), private :: m_accepted = 0
38+
!! The number of accepted steps.
3739
contains
3840
procedure, public :: get_state_variable_count => mh_get_nvars
3941
procedure, public :: get_chain_length => mh_get_chain_length
@@ -54,6 +56,8 @@ module fstats_mcmc
5456
procedure, public :: get_proposal_covariance => mh_get_prop_cov
5557
procedure, public :: set_proposal_covariance => mh_set_prop_cov
5658
procedure, public :: get_proposal_cholesky => mh_get_prop_chol_cov
59+
procedure, public :: evaluate_proposal_pdf => mh_eval_proposal
60+
procedure, public :: get_accepted_count => mh_get_num_accepted
5761

5862
! Private Routines
5963
procedure, private :: resize_buffer => mh_resize_buffer
@@ -262,6 +266,21 @@ function mh_proposal(this, xc) result(rst)
262266
rst = sample_normal_multivariate(this%m_propDist)
263267
end function
264268

269+
! ------------------------------------------------------------------------------
270+
pure function mh_eval_proposal(this, xc) result(rst)
271+
!! Evaluates the proposal distribution PDF at the specified set of
272+
!! variables.
273+
class(metropolis_hastings), intent(in) :: this
274+
!! The metropolis_hastings object.
275+
real(real64), intent(in), dimension(:) :: xc
276+
!! The array of variables to evaluate.
277+
real(real64) :: rst
278+
!! The value of the PDF at xc.
279+
280+
! Process
281+
rst = this%m_propDist%pdf(xc)
282+
end function
283+
265284
! ------------------------------------------------------------------------------
266285
function mh_hastings_ratio(this, xc, xp) result(rst)
267286
!! Evaluates the Hasting's ratio. If the proposal distribution is
@@ -363,6 +382,7 @@ subroutine mh_sample(this, xi, niter, err)
363382
npts = this%initial_iteration_estimate
364383
end if
365384
n = size(xi)
385+
this%m_accepted = 0
366386

367387
! Initialize the proposal distribution. Use an identity matrix for the
368388
! covariance matrix and assume a zero mean.
@@ -374,6 +394,7 @@ subroutine mh_sample(this, xi, niter, err)
374394
! Store the initial value
375395
call this%push_new_state(xi, err = errmgr)
376396
if (errmgr%has_error_occurred()) return
397+
this%m_accepted = 1
377398

378399
! Iteration Process
379400
xc = xi
@@ -399,6 +420,9 @@ subroutine mh_sample(this, xi, niter, err)
399420
xc = xp
400421
pc = pp
401422

423+
! Log the success
424+
this%m_accepted = this%m_accepted + 1
425+
402426
! Take additional actions on success???
403427
call this%on_acceptance(i, alpha, xc, xp, err = errmgr)
404428
if (errmgr%has_error_occurred()) return
@@ -649,5 +673,15 @@ pure function mh_get_prop_chol_cov(this) result(rst)
649673
end if
650674
end function
651675

676+
! ------------------------------------------------------------------------------
677+
pure function mh_get_num_accepted(this) result(rst)
678+
!! Gets the number of accepted steps.
679+
class(metropolis_hastings), intent(in) :: this
680+
!! The metropolis_hastings object.
681+
integer(int32) :: rst
682+
!! The number of accepted steps.
683+
rst = this%m_accepted
684+
end function
685+
652686
! ------------------------------------------------------------------------------
653687
end module

src/fstats_mcmc_fitting.f90

Lines changed: 147 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,24 @@ module fstats_mcmc_fitting
4444
real(real64), private, allocatable, dimension(:) :: m_f0
4545
!! An N-element array used for containing the current function
4646
!! estimate (N = size(x)).
47+
real(real64), private, allocatable, dimension(:) :: m_mean
48+
!! A NP-element array used to contain a running mean of each
49+
!! parameter.
4750

4851
! -----
4952
! Private Member Variables
50-
real(real64), private :: m_modelVariance = 1.0d0
51-
!! The variance of the residual error of the current model.
53+
real(real64), private :: m_dataVariance = 1.0d0
54+
!! The variance within the data set itself.
5255
contains
5356
procedure, public :: generate_proposal => mr_proposal
57+
procedure, public :: likelihood => mr_likelihood
5458
procedure, public :: target_distribution => mr_target
5559
procedure, public :: covariance_matrix => mr_covariance
5660
procedure, public :: compute_fit_statistics => mr_calc_regression_stats
57-
procedure, public :: get_target_variance => mr_get_target_variance
58-
procedure, public :: set_target_variance => mr_set_target_variance
61+
procedure, public :: get_data_variance => mr_get_data_variance
62+
procedure, public :: set_data_variance => mr_set_data_variance
63+
procedure, public :: on_acceptance => mr_on_success
64+
procedure, public :: push_new_state => mr_push
5965
end type
6066

6167
contains
@@ -90,21 +96,19 @@ function mr_proposal(this, xc) result(rst)
9096
end function
9197

9298
! ------------------------------------------------------------------------------
93-
! https://scalismo.org/docs/Tutorials/tutorial14
94-
function mr_target(this, x) result(rst)
95-
!! Returns the probability value from the target distribution given the
96-
!! current set of model parameters.
99+
function mr_likelihood(this, x) result(rst)
100+
!! Estimates the likelihood of the model.
97101
!!
98-
!! The probability value is determined as follows, assuming \(f(x)\)
99-
!! is the function value.
100-
!! $$ \prod_{i=1}^{n} p \left( y_{i} | \theta, x_{i} \right) =
101-
!! \prod_{i=1}^{n} N \left(y_{i} | f(x_{i}), \sigma^2 \right) $$.
102+
!! The likelihood is computed as follows assuming \(\sigma^2\) is known
103+
!! a priori.
104+
!! $$ L \left( \theta \right) = \prod_{i=1}^{n} N \left(y_{i} | f(x_{i}),
105+
!! \sigma^2 \right) $$
102106
class(mcmc_regression), intent(inout) :: this
103107
!! The mcmc_regression object.
104108
real(real64), intent(in), dimension(:) :: x
105109
!! The current set of model parameters.
106110
real(real64) :: rst
107-
!! The value of the probability density function being sampled.
111+
!! The likelihood value.
108112

109113
! Local Variables
110114
type(normal_distribution) :: dist
@@ -124,7 +128,7 @@ function mr_target(this, x) result(rst)
124128
temp = 1.0d0
125129
ep = 0
126130
rst = 1.0d0
127-
dist%standard_deviation = sqrt(this%get_target_variance())
131+
dist%standard_deviation = sqrt(this%get_data_variance())
128132
do i = 1, npts
129133
dist%mean_value = this%m_f0(i)
130134
p = dist%pdf(this%y(i))
@@ -147,6 +151,28 @@ function mr_target(this, x) result(rst)
147151
rst = temp * (1.0d1)**ep
148152
end function
149153

154+
! ------------------------------------------------------------------------------
155+
! https://scalismo.org/docs/Tutorials/tutorial14
156+
function mr_target(this, x) result(rst)
157+
!! Returns the probability value from the target distribution given the
158+
!! current set of model parameters.
159+
!!
160+
!! The probability value is determined as follows, assuming \(f(x)\)
161+
!! is the function value.
162+
!! $$ \prod_{i=1}^{n} p \left( y_{i} | \theta, x_{i} \right) =
163+
!! p \left( \theta, \sigma^2 \right)
164+
!! \prod_{i=1}^{n} N \left(y_{i} | f(x_{i}), \sigma^2 \right) $$.
165+
class(mcmc_regression), intent(inout) :: this
166+
!! The mcmc_regression object.
167+
real(real64), intent(in), dimension(:) :: x
168+
!! The current set of model parameters.
169+
real(real64) :: rst
170+
!! The value of the probability density function being sampled.
171+
172+
! Process
173+
rst = this%likelihood(x) * this%evaluate_proposal_pdf(x)
174+
end function
175+
150176
! ------------------------------------------------------------------------------
151177
function mr_covariance(this, xc, err) result(rst)
152178
!! Computes the covariance matrix for the model given the specified model
@@ -261,25 +287,126 @@ function mr_calc_regression_stats(this, xc, alpha, err) result(rst)
261287
end function
262288

263289
! ------------------------------------------------------------------------------
264-
pure function mr_get_target_variance(this) result(rst)
265-
!! Gets the variance of the target distribution.
290+
pure function mr_get_data_variance(this) result(rst)
291+
!! Gets the variance of the observed data.
266292
class(mcmc_regression), intent(in) :: this
267293
!! The mcmc_regression object.
268294
real(real64) :: rst
269295
!! The variance.
270296

271-
rst = this%m_modelVariance
297+
rst = this%m_dataVariance
272298
end function
273299

274300
! ------------------------------------------------------------------------------
275-
subroutine mr_set_target_variance(this, x)
276-
!! Sets the variance of the target distribution.
301+
subroutine mr_set_data_variance(this, x)
302+
!! Sets the variance of the observed data.
277303
class(mcmc_regression), intent(inout) :: this
278304
!! The mcmc_regression object.
279305
real(real64), intent(in) :: x
280306
!! The variance.
281307

282-
this%m_modelVariance = x
308+
this%m_dataVariance = x
309+
end subroutine
310+
311+
! ------------------------------------------------------------------------------
312+
subroutine mr_on_success(this, iter, alpha, xc, xp, err)
313+
!! Updates the covariance matrix of the proposal distribution upon a
314+
!! successful step. If overloaded, be sure to call the base method to
315+
!! retain the functionallity required to keep the covariance matrix
316+
!! up-to-date.
317+
class(mcmc_regression), intent(inout) :: this
318+
!! The mcmc_regression object.
319+
integer(int32), intent(in) :: iter
320+
!! The current iteration number.
321+
real(real64), intent(in) :: alpha
322+
!! The proposal probability term used for acceptance criteria.
323+
real(real64), intent(in), dimension(:) :: xc
324+
!! The current model parameter estimates.
325+
real(real64), intent(in), dimension(size(xc)) :: xp
326+
!! The recently accepted model parameter estimates.
327+
class(errors), intent(inout), optional, target :: err
328+
!! An error handling object.
329+
330+
! Local Variables
331+
integer(int32) :: i, j, n, np
332+
real(real64) :: nm1, nm2, ratio
333+
real(real64), allocatable, dimension(:) :: delta
334+
real(real64), allocatable, dimension(:,:) :: sig
335+
336+
! Updates the estimate of the covariance matrix by implementing Roberts &
337+
! Rosenthals adaptive approach.
338+
!
339+
! Parameters:
340+
! - xp: NP-by-1 array of the newest sampled points
341+
! - xm: NP-by-1 array of the updated mean over all samples
342+
! - sig: NP-by-NP old covariance matrix
343+
! - n: # of samples drawn
344+
!
345+
! C = (n - 2) / (n - 1) * sig + matmul(xp - xm, transpose(xp - xm)) / (n - 1)
346+
np = size(xc)
347+
n = this%get_chain_length()
348+
if (n == 1 .or. .not.allocated(this%m_mean)) then
349+
! No action is necessary
350+
return
351+
end if
352+
nm1 = n - 1.0d0
353+
nm2 = n - 2.0d0
354+
ratio = nm2 / nm1
355+
delta = xp - this%m_mean
356+
sig = this%get_proposal_covariance()
357+
358+
do j = 1, np
359+
do i = 1, np
360+
sig(i,j) = ratio * sig(i,j) + delta(i) * delta(j) / nm1
361+
end do
362+
end do
363+
364+
! Update the covariance matrix
365+
call this%set_proposal_covariance(sig, err = err)
366+
end subroutine
367+
368+
! ------------------------------------------------------------------------------
369+
subroutine mr_push(this, x, err)
370+
!! Pushes a new set of parameters onto the end of the chain buffer.
371+
class(mcmc_regression), intent(inout) :: this
372+
!! The mcmc_regression object.
373+
real(real64), intent(in), dimension(:) :: x
374+
!! The new N-element state array.
375+
class(errors), intent(inout), optional, target :: err
376+
!! An error handling object.
377+
378+
! Local Variables
379+
integer(int32) :: n, npts, flag
380+
class(errors), pointer :: errmgr
381+
type(errors), target :: deferr
382+
383+
! Initialization
384+
if (present(err)) then
385+
errmgr => err
386+
else
387+
errmgr => deferr
388+
end if
389+
390+
! Push the item onto the stack using the base method
391+
call this%metropolis_hastings%push_new_state(x, err = errmgr)
392+
if (errmgr%has_error_occurred()) return
393+
394+
! Update the running average term
395+
n = size(x)
396+
npts = this%get_chain_length()
397+
if (.not.allocated(this%m_mean)) then
398+
allocate(this%m_mean(n), stat = flag, source = x)
399+
if (flag /= 0) then
400+
call report_memory_error(errmgr, "mr_push", flag)
401+
return
402+
end if
403+
404+
! No more action is necessary - end here
405+
return
406+
end if
407+
408+
! Update the mean
409+
this%m_mean = (npts * this%m_mean + x) / (npts + 1.0d0)
283410
end subroutine
284411

285412
! ------------------------------------------------------------------------------

0 commit comments

Comments
 (0)