Skip to content

Commit eacba8b

Browse files
Add example
1 parent 81341db commit eacba8b

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed

examples/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,9 @@ add_executable(mcmc_regression_example mcmc_regression_example.f90)
6565
target_link_libraries(mcmc_regression_example fstats)
6666
target_link_libraries(mcmc_regression_example ${fplot_LIBRARY})
6767
target_include_directories(mcmc_regression_example PUBLIC ${fplot_INCLUDE_DIR})
68+
69+
# MCMC Regression Example 2
70+
add_executable(mcmc_regression_example_2 mcmc_regression_example_2.f90)
71+
target_link_libraries(mcmc_regression_example_2 fstats)
72+
target_link_libraries(mcmc_regression_example_2 ${fplot_LIBRARY})
73+
target_include_directories(mcmc_regression_example_2 PUBLIC ${fplot_INCLUDE_DIR})
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
module functions
2+
use iso_fortran_env
3+
implicit none
4+
5+
contains
6+
7+
subroutine fit_fcn(x, p, f, stop)
8+
real(real64), intent(in), dimension(:) :: x
9+
!! The independent variable data array.
10+
real(real64), intent(in), dimension(:) :: p
11+
!! The array of model parameters.
12+
real(real64), intent(out), dimension(:) :: f
13+
!! The function values evaluated at x.
14+
logical, intent(out) :: stop
15+
!! Set to true to force a stop to the process; else,
16+
!! set to false to proceed as normal.
17+
18+
! The function to fit
19+
f = p(1) * exp(p(2) * x) * sin(p(3) * x)
20+
stop = .false.
21+
end subroutine
22+
23+
end module
24+
25+
26+
! ----------
27+
program example
28+
use iso_fortran_env
29+
use functions
30+
use fstats
31+
use fplot_core
32+
implicit none
33+
34+
! Parameters
35+
integer(int32), parameter :: npts = 100
36+
real(real64), parameter :: dt = 1.0d-2
37+
character, parameter :: tab = achar(9)
38+
character, parameter :: nl = new_line('a')
39+
40+
! Model Parameters
41+
real(real64), parameter :: p1 = 1.5d0
42+
real(real64), parameter :: p2 = -5.0d-1
43+
real(real64), parameter :: p3 = 5.0d1
44+
45+
! Noise Properties
46+
real(real64), parameter :: sigma = 1.0d-1
47+
real(real64), parameter :: mu = 0.0d0
48+
real(real64), parameter :: range = 2.0d-1
49+
50+
! Regression Parameters
51+
real(real64), parameter :: s11 = 1.0d-1
52+
real(real64), parameter :: s22 = 1.0d-1
53+
real(real64), parameter :: s33 = 1.0d-1
54+
55+
! Local Variables
56+
logical :: stop
57+
integer(int32) :: i, burnin
58+
real(real64) :: t(npts), x(npts), noise(npts), xi(3), s(3, 3), mdl(3), &
59+
f(npts)
60+
real(real64), allocatable, dimension(:,:) :: chain
61+
type(normal_distribution) :: ndist
62+
type(mcmc_regression) :: solver
63+
type(regression_statistics), allocatable, dimension(:) :: stats
64+
65+
! Plot Variables
66+
type(multiplot) :: mplt
67+
type(plot_2d) :: plt, plt1, plt2, plt3
68+
type(plot_data_2d) :: pd1, pd2
69+
class(terminal), pointer :: term
70+
class(plot_axis), pointer :: x1, x2, x3
71+
72+
! Build the signal and corrupt it a bit with some noise
73+
t = (/ (i * dt, i = 0, npts - 1) /)
74+
ndist%mean_value = mu
75+
ndist%standard_deviation = sigma
76+
noise = rejection_sample(ndist, npts, -range, range)
77+
x = p1 * exp(p2 * t) * sin(p3 * t) + noise
78+
79+
! Set up the regression solver
80+
solver%x = t
81+
solver%y = x
82+
solver%fcn => fit_fcn
83+
84+
! Define upper and lower limits for each parameter (optional)
85+
solver%upper_limits = [1.0d1, 0.0d0, 1.0d2]
86+
solver%lower_limits = [0.1d0, -1.0d0, 1.0d1]
87+
88+
! Define an initial guess
89+
xi = [1.0d0, -0.5d0, 2.0d1]
90+
91+
! Set up the proposal distribution for the solver
92+
s = reshape([&
93+
s11, 0.0d0, 0.0d0, &
94+
0.0d0, s22, 0.0d0, &
95+
0.0d0, 0.0d0, s33], &
96+
[3, 3] &
97+
)
98+
call solver%initialize_proposal(xi, s)
99+
100+
! Compute the fit - sample 100,000 times
101+
call solver%sample(xi, niter = 100000)
102+
103+
! Get the chain
104+
chain = solver%get_chain()
105+
106+
! Extract the model - use the mean values and ignore the initial
107+
! burn-in. Notice, the burn-in section can be ignored via the call to
108+
! get_chain above by using the optional argument "bin" to define the
109+
! percentage of the chain to effectively throw away. I'm choosing to
110+
! do this way to illustrate the full chain in the plots.
111+
burnin = 3 * size(chain, 1) / 4
112+
mdl = [ &
113+
mean(chain(burnin:,1)), &
114+
mean(chain(burnin:,2)), &
115+
mean(chain(burnin:,3)) &
116+
]
117+
118+
! Evaluate the model
119+
call fit_fcn(t, mdl, f, stop)
120+
121+
! Compute the fit statistics
122+
stats = solver%compute_fit_statistics(mdl)
123+
124+
! Display the model parameters and stats
125+
print 100, ( &
126+
"Coefficient ", i, ":" // nl // &
127+
tab // "Value: ", mdl(i), nl // &
128+
tab // "Standard Error: ", stats(i)%standard_error, nl // &
129+
tab // "Confidence Interval: +/-", stats(i)%confidence_interval, nl // &
130+
tab // "T-Statistic: ", stats(i)%t_statistic, nl // &
131+
tab // "P-Value: ", stats(i)%probability, &
132+
i = 1, size(stats) &
133+
)
134+
135+
! ----------
136+
! Plot the fit
137+
call plt%initialize()
138+
call pd1%define_data(t, f)
139+
call plt%push(pd1)
140+
call pd2%define_data(t, x)
141+
call pd2%set_draw_line(.false.)
142+
call pd2%set_draw_markers(.true.)
143+
call pd2%set_marker_style(MARKER_FILLED_CIRCLE)
144+
call pd2%set_marker_scaling(0.5)
145+
call plt%push(pd2)
146+
call plt%draw()
147+
148+
! ----------
149+
! Plot the chains
150+
call mplt%initialize(3, 1)
151+
call plt1%initialize()
152+
call plt2%initialize()
153+
call plt3%initialize()
154+
x1 => plt1%get_x_axis()
155+
x2 => plt2%get_x_axis()
156+
x3 => plt3%get_x_axis()
157+
term => mplt%get_terminal()
158+
call term%set_window_height(800)
159+
call term%set_window_width(1000)
160+
call x1%set_use_default_tic_label_format(.false.)
161+
call x1%set_tic_label_format("%0.0e")
162+
call plt1%set_title("p_1")
163+
call plt2%set_title("p_2")
164+
call plt3%set_title("p_3")
165+
call pd1%define_data(chain(:,1))
166+
call plt1%push(pd1)
167+
call pd1%define_data(chain(:,2))
168+
call plt2%push(pd1)
169+
call pd1%define_data(chain(:,3))
170+
call plt3%push(pd1)
171+
call mplt%set(1, 1, plt1)
172+
call mplt%set(2, 1, plt2)
173+
call mplt%set(3, 1, plt3)
174+
call mplt%draw()
175+
176+
! -----
177+
100 format(A, I0, A F6.3, A, F6.3, A, F6.3, A, F8.3, A, F6.3)
178+
end program

0 commit comments

Comments
 (0)