Skip to content

Commit f9d3276

Browse files
Add regression function argument passing.
1 parent ba93de5 commit f9d3276

File tree

3 files changed

+155
-25
lines changed

3 files changed

+155
-25
lines changed

examples/nl_regression_example.f90

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
module nl_example
22
use iso_fortran_env
33
contains
4-
subroutine exfun(x, p, f, stop)
4+
subroutine exfun(x, p, f, stop, args)
55
! Arguments
66
real(real64), intent(in) :: x(:), p(:)
77
real(real64), intent(out) :: f(:)
88
logical, intent(out) :: stop
9+
class(*), intent(inout), optional :: args
910

1011
! Function
1112
f = p(4) * x**3 + p(3) * x**2 + p(2) * x + p(1)

src/fstats_regression.f90

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ module fstats_regression
137137
end type
138138

139139
interface
140-
subroutine regression_function(xdata, params, f, stop)
140+
subroutine regression_function(xdata, params, f, stop, args)
141141
!! Defines the interface of a subroutine computing the function
142142
!! values at each of the N data points as part of a regression
143143
!! analysis.
@@ -154,6 +154,8 @@ subroutine regression_function(xdata, params, f, stop)
154154
!! set to true, the iteration process will terminate. If set
155155
!! to false, the iteration process will continue along as
156156
!! normal.
157+
class(*), intent(inout), optional :: args
158+
!! An optional argument allowing the passing in/out of data.
157159
end subroutine
158160

159161
subroutine iteration_update(iter, funvals, resid, params, step)
@@ -686,7 +688,7 @@ function calculate_regression_statistics(resid, params, c, alpha, err) &
686688

687689
! ------------------------------------------------------------------------------
688690
subroutine jacobian(fun, xdata, params, &
689-
jac, stop, f0, f1, step, err)
691+
jac, stop, f0, f1, step, args, err)
690692
!! Computes the Jacobian matrix for a nonlinear regression problem.
691693
procedure(regression_function), intent(in), pointer :: fun
692694
!! A pointer to the regression_function to evaluate.
@@ -711,6 +713,8 @@ subroutine jacobian(fun, xdata, params, &
711713
real(real64), intent(in), optional :: step
712714
!! The differentiation step size. The default is the square
713715
!! root of machine precision.
716+
class(*), intent(inout), optional :: args
717+
!! An optional argument allowing the passing in/out of data.
714718
class(errors), intent(inout), optional, target :: err
715719
!! A mechanism for communicating errors and warnings to the
716720
!! caller. Possible warning and error codes are as follows.
@@ -762,7 +766,7 @@ subroutine jacobian(fun, xdata, params, &
762766
allocate(f0a(m), stat = flag)
763767
if (flag /= 0) go to 20
764768
f0p(1:m) => f0a
765-
call fun(xdata, params, f0p, stop)
769+
call fun(xdata, params, f0p, stop, args = args)
766770
if (stop) return
767771
end if
768772
if (present(f1)) then
@@ -786,7 +790,7 @@ subroutine jacobian(fun, xdata, params, &
786790

787791
! Compute the Jacobian
788792
call jacobian_finite_diff(fun, xdata, params, f0p, jac, f1p, &
789-
stop, h, work)
793+
stop, h, work, args = args)
790794

791795
! End
792796
return
@@ -800,7 +804,7 @@ subroutine jacobian(fun, xdata, params, &
800804
! ------------------------------------------------------------------------------
801805
subroutine nonlinear_least_squares(fun, x, y, params, ymod, &
802806
resid, weights, maxp, minp, stats, alpha, controls, settings, info, &
803-
status, cov, err)
807+
status, cov, args, err)
804808
!! Performs a nonlinear regression to fit a model using a version
805809
!! of the Levenberg-Marquardt algorithm.
806810
procedure(regression_function), intent(in), pointer :: fun
@@ -853,6 +857,9 @@ subroutine nonlinear_least_squares(fun, x, y, params, ymod, &
853857
real(real64), intent(out), optional, dimension(:,:) :: cov
854858
!! An optional N-by-N matrix that, if supplied, will be used to return
855859
!! the covariance matrix.
860+
class(*), intent(inout), optional :: args
861+
!! An optional argument allowing the passing in/out of data for the
862+
!! [[fun]] routine.
856863
class(errors), intent(inout), optional, target :: err
857864
!! A mechanism for communicating errors and warnings to the
858865
!! caller. Possible warning and error codes are as follows.
@@ -1035,7 +1042,7 @@ subroutine nonlinear_least_squares(fun, x, y, params, ymod, &
10351042

10361043
! Process
10371044
call lm_solve(fun, x, y, params, w, pmax, pmin, tol, opt, ymod, &
1038-
resid, JtWJ, inf, stop, errmgr, status)
1045+
resid, JtWJ, inf, stop, errmgr, status, args = args)
10391046

10401047
! Compute the covariance matrix
10411048
if (present(stats) .or. present(cov)) then
@@ -1124,7 +1131,7 @@ subroutine lm_set_default_settings(x)
11241131
! - stop: A flag allowing the user to terminate model execution
11251132
! - work: A workspace array for the model parameters (N-by-1)
11261133
subroutine jacobian_finite_diff(fun, xdata, params, f0, jac, f1, &
1127-
stop, step, work)
1134+
stop, step, work, args)
11281135
! Arguments
11291136
procedure(regression_function), intent(in), pointer :: fun
11301137
real(real64), intent(in), dimension(:) :: xdata, params
@@ -1133,6 +1140,7 @@ subroutine jacobian_finite_diff(fun, xdata, params, f0, jac, f1, &
11331140
real(real64), intent(out), dimension(:) :: f1, work
11341141
logical, intent(out) :: stop
11351142
real(real64), intent(in) :: step
1143+
class(*), intent(inout), optional :: args
11361144

11371145
! Local Variables
11381146
integer(int32) :: i, n
@@ -1147,7 +1155,7 @@ subroutine jacobian_finite_diff(fun, xdata, params, f0, jac, f1, &
11471155
work = params
11481156
do i = 1, n
11491157
work(i) = work(i) + step
1150-
call fun(xdata, work, f1, stop)
1158+
call fun(xdata, work, f1, stop, args = args)
11511159
if (stop) return
11521160

11531161
jac(:,i) = (f1 - f0) / step
@@ -1218,7 +1226,7 @@ subroutine broyden_update(pOld, yOld, jac, p, y, dp, dy)
12181226
! - mwork: A workspace matrix (N-by-M)
12191227
! - update: Reset to false if a Jacobian evaluation was performed.
12201228
subroutine lm_matrix(fun, xdata, ydata, pOld, yOld, dX2, jac, p, weights, &
1221-
neval, update, step, JtWJ, JtWdy, X2, yNew, stop, work, mwork)
1229+
neval, update, step, JtWJ, JtWdy, X2, yNew, stop, work, mwork, args)
12221230
! Arguments
12231231
procedure(regression_function), pointer :: fun
12241232
real(real64), intent(in), dimension(:) :: xdata, ydata, pOld, yOld, &
@@ -1234,6 +1242,7 @@ subroutine lm_matrix(fun, xdata, ydata, pOld, yOld, dX2, jac, p, weights, &
12341242
real(real64), intent(out), dimension(:) :: yNew
12351243
logical, intent(out) :: stop
12361244
real(real64), intent(out), target, dimension(:) :: work
1245+
class(*), intent(inout), optional :: args
12371246

12381247
! Local Variables
12391248
integer(int32) :: m, n
@@ -1246,15 +1255,15 @@ subroutine lm_matrix(fun, xdata, ydata, pOld, yOld, dX2, jac, p, weights, &
12461255
w2(1:n) => work(m+1:n+m)
12471256

12481257
! Perform the next function evaluation
1249-
call fun(xdata, p, yNew, stop)
1258+
call fun(xdata, p, yNew, stop, args = args)
12501259
neval = neval + 1
12511260
if (stop) return
12521261

12531262
! Update or recompute the Jacobian matrix
12541263
if (dX2 > 0 .or. update) then
12551264
! Recompute the Jacobian
12561265
call jacobian_finite_diff(fun, xdata, p, yNew, jac, w1, &
1257-
stop, step, w2)
1266+
stop, step, w2, args = args)
12581267
neval = neval + n
12591268
if (stop) return
12601269
update = .false.
@@ -1311,7 +1320,7 @@ subroutine lm_matrix(fun, xdata, ydata, pOld, yOld, dX2, jac, p, weights, &
13111320
! - err: An error handling mechanism
13121321
subroutine lm_iter(fun, xdata, ydata, p, neval, niter, update, lambda, &
13131322
maxP, minP, weights, JtWJ, JtWdy, h, pNew, deltaY, yNew, X2, X2Old, &
1314-
alpha, stop, iwork, err, status)
1323+
alpha, stop, iwork, err, status, args)
13151324
! Arguments
13161325
procedure(regression_function), pointer :: fun
13171326
real(real64), intent(in) :: xdata(:), ydata(:), p(:), maxP(:), &
@@ -1326,6 +1335,7 @@ subroutine lm_iter(fun, xdata, ydata, p, neval, niter, update, lambda, &
13261335
integer(int32), intent(out) :: iwork(:)
13271336
class(errors), intent(inout) :: err
13281337
procedure(iteration_update), intent(in), pointer, optional :: status
1338+
class(*), intent(inout), optional :: args
13291339

13301340
! Local Variables
13311341
integer(int32) :: i, n
@@ -1367,7 +1377,7 @@ subroutine lm_iter(fun, xdata, ydata, p, neval, niter, update, lambda, &
13671377
end do
13681378

13691379
! Update the residual error
1370-
call fun(xdata, pNew, yNew, stop)
1380+
call fun(xdata, pNew, yNew, stop, args = args)
13711381
neval = neval + 1
13721382
deltaY = ydata - yNew
13731383
if (stop) return
@@ -1385,7 +1395,7 @@ subroutine lm_iter(fun, xdata, ydata, p, neval, niter, update, lambda, &
13851395
pNew(i) = min(max(minP(i), p(i) + h(i)), maxP(i))
13861396
end do
13871397

1388-
call fun(xdata, pNew, yNew, stop)
1398+
call fun(xdata, pNew, yNew, stop, args = args)
13891399
if (stop) return
13901400
neval = neval + 1
13911401
deltaY = ydata - yNew
@@ -1421,7 +1431,7 @@ subroutine lm_iter(fun, xdata, ydata, p, neval, niter, update, lambda, &
14211431
! - stop: A flag allowing the user to terminate model execution
14221432
! - err: An error handling object
14231433
subroutine lm_solve(fun, xdata, ydata, p, weights, maxP, minP, controls, &
1424-
opt, y, resid, JtWJ, info, stop, err, status)
1434+
opt, y, resid, JtWJ, info, stop, err, status, args)
14251435
! Arguments
14261436
procedure(regression_function), intent(in), pointer :: fun
14271437
real(real64), intent(in) :: xdata(:), ydata(:), weights(:), maxP(:), &
@@ -1434,6 +1444,7 @@ subroutine lm_solve(fun, xdata, ydata, p, weights, maxP, minP, controls, &
14341444
logical, intent(out) :: stop
14351445
class(errors), intent(inout) :: err
14361446
procedure(iteration_update), intent(in), pointer, optional :: status
1447+
class(*), intent(inout), optional :: args
14371448

14381449
! Local Variables
14391450
logical :: update
@@ -1470,12 +1481,12 @@ subroutine lm_solve(fun, xdata, ydata, p, weights, maxP, minP, controls, &
14701481
if (flag /= 0) go to 10
14711482

14721483
! Perform an initial function evaluation
1473-
call fun(xdata, p, y, stop)
1484+
call fun(xdata, p, y, stop, args = args)
14741485
neval = 1
14751486

14761487
! Evaluate the problem matrices
14771488
call lm_matrix(fun, xdata, ydata, pOld, yOld, 1.0d0, J, p, weights, &
1478-
neval, update, step, JtWJ, JtWdy, X2, y, stop, work, mwork)
1489+
neval, update, step, JtWJ, JtWdy, X2, y, stop, work, mwork, args = args)
14791490
if (stop) go to 5
14801491
X2Old = X2
14811492
JtWJc = JtWJ
@@ -1495,15 +1506,16 @@ subroutine lm_solve(fun, xdata, ydata, p, weights, maxP, minP, controls, &
14951506
! update the new parameter estimates
14961507
call lm_iter(fun, xdata, ydata, p, neval, niter, opt%method, &
14971508
lambda, maxP, minP, weights, JtWJc, JtWdy, h, pTry, resid, &
1498-
yTemp, X2Try, X2Old, alpha, stop, iwork, err, status)
1509+
yTemp, X2Try, X2Old, alpha, stop, iwork, err, status, args = args)
14991510
if (stop) go to 5
15001511
if (err%has_error_occurred()) return
15011512

15021513
! Update the Chi-squared estimate, update the damping parameter
15031514
! lambda, and, if necessary, update the matrices
15041515
call lm_update(fun, xdata, ydata, pOld, p, pTry, yOld, y, h, dX2, &
15051516
X2Old, X2, X2Try, lambda, alpha, nu, JtWdy, JtWJ, J, weights, &
1506-
niter, neval, update, step, work, mwork, controls, opt, stop)
1517+
niter, neval, update, step, work, mwork, controls, opt, stop, &
1518+
args = args)
15071519
if (stop) go to 5
15081520
JtWJc = JtWJ
15091521

@@ -1548,7 +1560,7 @@ subroutine lm_solve(fun, xdata, ydata, p, weights, maxP, minP, controls, &
15481560
!
15491561
subroutine lm_update(fun, xdata, ydata, pOld, p, pTry, yOld, y, h, dX2, &
15501562
X2old, X2, X2try, lambda, alpha, nu, JtWdy, JtWJ, J, weights, niter, &
1551-
neval, update, step, work, mwork, controls, opt, stop)
1563+
neval, update, step, work, mwork, controls, opt, stop, args)
15521564
! Arguments
15531565
procedure(regression_function), intent(in), pointer :: fun
15541566
real(real64), intent(in) :: xdata(:), ydata(:), X2try, h(:), step, &
@@ -1562,6 +1574,7 @@ subroutine lm_update(fun, xdata, ydata, pOld, p, pTry, yOld, y, h, dX2, &
15621574
class(iteration_controls), intent(in) :: controls
15631575
class(lm_solver_options), intent(in) :: opt
15641576
logical, intent(out) :: stop
1577+
class(*), intent(inout), optional :: args
15651578

15661579
! Local Variables
15671580
integer(int32) :: n
@@ -1588,7 +1601,8 @@ subroutine lm_update(fun, xdata, ydata, pOld, p, pTry, yOld, y, h, dX2, &
15881601

15891602
! Recompute the matrices
15901603
call lm_matrix(fun, xdata, ydata, pOld, yOld, dX2, J, p, weights, &
1591-
neval, update, step, JtWJ, JtWdy, X2, y, stop, work, mwork)
1604+
neval, update, step, JtWJ, JtWdy, X2, y, stop, work, mwork, &
1605+
args = args)
15921606
if (stop) return
15931607

15941608
! Decrease lambda
@@ -1608,7 +1622,7 @@ subroutine lm_update(fun, xdata, ydata, pOld, p, pTry, yOld, y, h, dX2, &
16081622
if (mod(niter, 2 * n) /= 0) then
16091623
call lm_matrix(fun, xdata, ydata, pOld, yOld, -1.0d0, J, p, &
16101624
weights, neval, update, step, JtWJ, JtWdy, dX2, y, stop, &
1611-
work, mwork)
1625+
work, mwork, args = args)
16121626
if (stop) return
16131627
end if
16141628

0 commit comments

Comments
 (0)