Skip to content

Commit 906ab15

Browse files
Add argument passing to nl regression routine
1 parent 498496a commit 906ab15

File tree

4 files changed

+156
-24
lines changed

4 files changed

+156
-24
lines changed

examples/nl_regression_example.f90

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
module nl_example
22
use iso_fortran_env
33
contains
4-
subroutine exfun(x, p, f, stop)
4+
subroutine exfun(x, p, f, stop, args)
55
! Arguments
66
real(real64), intent(in) :: x(:), p(:)
77
real(real64), intent(out) :: f(:)
88
logical, intent(out) :: stop
9+
class(*), intent(inout), optional :: args
910

1011
! Function
1112
f = p(4) * x**3 + p(3) * x**2 + p(2) * x + p(1)

src/fstats_regression.f90

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ module fstats_regression
137137
end type
138138

139139
interface
140-
subroutine regression_function(xdata, params, f, stop)
140+
subroutine regression_function(xdata, params, f, stop, args)
141141
!! Defines the interface of a subroutine computing the function
142142
!! values at each of the N data points as part of a regression
143143
!! analysis.
@@ -154,6 +154,9 @@ subroutine regression_function(xdata, params, f, stop)
154154
!! set to true, the iteration process will terminate. If set
155155
!! to false, the iteration process will continue along as
156156
!! normal.
157+
class(*), intent(inout), optional :: args
158+
!! An optional parameter that can be used to pass data
159+
!! in and out of the routine.
157160
end subroutine
158161

159162
subroutine iteration_update(iter, funvals, resid, params, step)
@@ -686,7 +689,7 @@ function calculate_regression_statistics(resid, params, c, alpha, err) &
686689

687690
! ------------------------------------------------------------------------------
688691
subroutine jacobian(fun, xdata, params, &
689-
jac, stop, f0, f1, step, err)
692+
jac, stop, f0, f1, step, args, err)
690693
!! Computes the Jacobian matrix for a nonlinear regression problem.
691694
procedure(regression_function), intent(in), pointer :: fun
692695
!! A pointer to the regression_function to evaluate.
@@ -711,6 +714,9 @@ subroutine jacobian(fun, xdata, params, &
711714
real(real64), intent(in), optional :: step
712715
!! The differentiation step size. The default is the square
713716
!! root of machine precision.
717+
class(*), intent(inout), optional :: args
718+
!! An optional parameter that can be used to pass data
719+
!! in and out of the routine.
714720
class(errors), intent(inout), optional, target :: err
715721
!! A mechanism for communicating errors and warnings to the
716722
!! caller. Possible warning and error codes are as follows.
@@ -762,7 +768,7 @@ subroutine jacobian(fun, xdata, params, &
762768
allocate(f0a(m), stat = flag)
763769
if (flag /= 0) go to 20
764770
f0p(1:m) => f0a
765-
call fun(xdata, params, f0p, stop)
771+
call fun(xdata, params, f0p, stop, args)
766772
if (stop) return
767773
end if
768774
if (present(f1)) then
@@ -800,7 +806,7 @@ subroutine jacobian(fun, xdata, params, &
800806
! ------------------------------------------------------------------------------
801807
subroutine nonlinear_least_squares(fun, x, y, params, ymod, &
802808
resid, weights, maxp, minp, stats, alpha, controls, settings, info, &
803-
status, cov, err)
809+
status, cov, args, err)
804810
!! Performs a nonlinear regression to fit a model using a version
805811
!! of the Levenberg-Marquardt algorithm.
806812
procedure(regression_function), intent(in), pointer :: fun
@@ -853,6 +859,9 @@ subroutine nonlinear_least_squares(fun, x, y, params, ymod, &
853859
real(real64), intent(out), optional, dimension(:,:) :: cov
854860
!! An optional N-by-N matrix that, if supplied, will be used to return
855861
!! the covariance matrix.
862+
class(*), intent(inout), optional :: args
863+
!! An optional parameter that can be used to pass data
864+
!! in and out of the [[fun]] routine.
856865
class(errors), intent(inout), optional, target :: err
857866
!! A mechanism for communicating errors and warnings to the
858867
!! caller. Possible warning and error codes are as follows.
@@ -1035,7 +1044,7 @@ subroutine nonlinear_least_squares(fun, x, y, params, ymod, &
10351044

10361045
! Process
10371046
call lm_solve(fun, x, y, params, w, pmax, pmin, tol, opt, ymod, &
1038-
resid, JtWJ, inf, stop, errmgr, status)
1047+
resid, JtWJ, inf, stop, errmgr, status, args = args)
10391048

10401049
! Compute the covariance matrix
10411050
if (present(stats) .or. present(cov)) then
@@ -1124,7 +1133,7 @@ subroutine lm_set_default_settings(x)
11241133
! - stop: A flag allowing the user to terminate model execution
11251134
! - work: A workspace array for the model parameters (N-by-1)
11261135
subroutine jacobian_finite_diff(fun, xdata, params, f0, jac, f1, &
1127-
stop, step, work)
1136+
stop, step, work, args)
11281137
! Arguments
11291138
procedure(regression_function), intent(in), pointer :: fun
11301139
real(real64), intent(in) :: xdata(:), params(:)
@@ -1133,6 +1142,7 @@ subroutine jacobian_finite_diff(fun, xdata, params, f0, jac, f1, &
11331142
real(real64), intent(out) :: f1(:), work(:)
11341143
logical, intent(out) :: stop
11351144
real(real64), intent(in) :: step
1145+
class(*), intent(inout), optional :: args
11361146

11371147
! Local Variables
11381148
integer(int32) :: i, n
@@ -1147,7 +1157,7 @@ subroutine jacobian_finite_diff(fun, xdata, params, f0, jac, f1, &
11471157
work = params
11481158
do i = 1, n
11491159
work(i) = work(i) + step
1150-
call fun(xdata, work, f1, stop)
1160+
call fun(xdata, work, f1, stop, args)
11511161
if (stop) return
11521162

11531163
jac(:,i) = (f1 - f0) / step
@@ -1218,7 +1228,7 @@ subroutine broyden_update(pOld, yOld, jac, p, y, dp, dy)
12181228
! - mwork: A workspace matrix (N-by-M)
12191229
! - update: Reset to false if a Jacobian evaluation was performed.
12201230
subroutine lm_matrix(fun, xdata, ydata, pOld, yOld, dX2, jac, p, weights, &
1221-
neval, update, step, JtWJ, JtWdy, X2, yNew, stop, work, mwork)
1231+
neval, update, step, JtWJ, JtWdy, X2, yNew, stop, work, mwork, args)
12221232
! Arguments
12231233
procedure(regression_function), pointer :: fun
12241234
real(real64), intent(in) :: xdata(:), ydata(:), pOld(:), yOld(:), &
@@ -1231,6 +1241,7 @@ subroutine lm_matrix(fun, xdata, ydata, pOld, yOld, dX2, jac, p, weights, &
12311241
real(real64), intent(out) :: X2, mwork(:,:), yNew(:)
12321242
logical, intent(out) :: stop
12331243
real(real64), intent(out), target :: work(:)
1244+
class(*), intent(inout), optional :: args
12341245

12351246
! Local Variables
12361247
integer(int32) :: m, n
@@ -1243,15 +1254,15 @@ subroutine lm_matrix(fun, xdata, ydata, pOld, yOld, dX2, jac, p, weights, &
12431254
w2(1:n) => work(m+1:n+m)
12441255

12451256
! Perform the next function evaluation
1246-
call fun(xdata, p, yNew, stop)
1257+
call fun(xdata, p, yNew, stop, args)
12471258
neval = neval + 1
12481259
if (stop) return
12491260

12501261
! Update or recompute the Jacobian matrix
12511262
if (dX2 > 0 .or. update) then
12521263
! Recompute the Jacobian
12531264
call jacobian_finite_diff(fun, xdata, p, yNew, jac, w1, &
1254-
stop, step, w2)
1265+
stop, step, w2, args)
12551266
neval = neval + n
12561267
if (stop) return
12571268
update = .false.
@@ -1308,7 +1319,7 @@ subroutine lm_matrix(fun, xdata, ydata, pOld, yOld, dX2, jac, p, weights, &
13081319
! - err: An error handling mechanism
13091320
subroutine lm_iter(fun, xdata, ydata, p, neval, niter, update, lambda, &
13101321
maxP, minP, weights, JtWJ, JtWdy, h, pNew, deltaY, yNew, X2, X2Old, &
1311-
alpha, stop, iwork, err, status)
1322+
alpha, stop, iwork, err, status, args)
13121323
! Arguments
13131324
procedure(regression_function), pointer :: fun
13141325
real(real64), intent(in) :: xdata(:), ydata(:), p(:), maxP(:), &
@@ -1323,6 +1334,7 @@ subroutine lm_iter(fun, xdata, ydata, p, neval, niter, update, lambda, &
13231334
integer(int32), intent(out) :: iwork(:)
13241335
class(errors), intent(inout) :: err
13251336
procedure(iteration_update), intent(in), pointer, optional :: status
1337+
class(*), intent(inout), optional :: args
13261338

13271339
! Local Variables
13281340
integer(int32) :: i, n
@@ -1364,7 +1376,7 @@ subroutine lm_iter(fun, xdata, ydata, p, neval, niter, update, lambda, &
13641376
end do
13651377

13661378
! Update the residual error
1367-
call fun(xdata, pNew, yNew, stop)
1379+
call fun(xdata, pNew, yNew, stop, args)
13681380
neval = neval + 1
13691381
deltaY = ydata - yNew
13701382
if (stop) return
@@ -1382,7 +1394,7 @@ subroutine lm_iter(fun, xdata, ydata, p, neval, niter, update, lambda, &
13821394
pNew(i) = min(max(minP(i), p(i) + h(i)), maxP(i))
13831395
end do
13841396

1385-
call fun(xdata, pNew, yNew, stop)
1397+
call fun(xdata, pNew, yNew, stop, args)
13861398
if (stop) return
13871399
neval = neval + 1
13881400
deltaY = ydata - yNew
@@ -1418,7 +1430,7 @@ subroutine lm_iter(fun, xdata, ydata, p, neval, niter, update, lambda, &
14181430
! - stop: A flag allowing the user to terminate model execution
14191431
! - err: An error handling object
14201432
subroutine lm_solve(fun, xdata, ydata, p, weights, maxP, minP, controls, &
1421-
opt, y, resid, JtWJ, info, stop, err, status)
1433+
opt, y, resid, JtWJ, info, stop, err, status, args)
14221434
! Arguments
14231435
procedure(regression_function), intent(in), pointer :: fun
14241436
real(real64), intent(in) :: xdata(:), ydata(:), weights(:), maxP(:), &
@@ -1431,6 +1443,7 @@ subroutine lm_solve(fun, xdata, ydata, p, weights, maxP, minP, controls, &
14311443
logical, intent(out) :: stop
14321444
class(errors), intent(inout) :: err
14331445
procedure(iteration_update), intent(in), pointer, optional :: status
1446+
class(*), intent(inout), optional :: args
14341447

14351448
! Local Variables
14361449
logical :: update
@@ -1467,12 +1480,12 @@ subroutine lm_solve(fun, xdata, ydata, p, weights, maxP, minP, controls, &
14671480
if (flag /= 0) go to 10
14681481

14691482
! Perform an initial function evaluation
1470-
call fun(xdata, p, y, stop)
1483+
call fun(xdata, p, y, stop, args)
14711484
neval = 1
14721485

14731486
! Evaluate the problem matrices
14741487
call lm_matrix(fun, xdata, ydata, pOld, yOld, 1.0d0, J, p, weights, &
1475-
neval, update, step, JtWJ, JtWdy, X2, y, stop, work, mwork)
1488+
neval, update, step, JtWJ, JtWdy, X2, y, stop, work, mwork, args)
14761489
if (stop) go to 5
14771490
X2Old = X2
14781491
JtWJc = JtWJ
@@ -1492,15 +1505,15 @@ subroutine lm_solve(fun, xdata, ydata, p, weights, maxP, minP, controls, &
14921505
! update the new parameter estimates
14931506
call lm_iter(fun, xdata, ydata, p, neval, niter, opt%method, &
14941507
lambda, maxP, minP, weights, JtWJc, JtWdy, h, pTry, resid, &
1495-
yTemp, X2Try, X2Old, alpha, stop, iwork, err, status)
1508+
yTemp, X2Try, X2Old, alpha, stop, iwork, err, status, args)
14961509
if (stop) go to 5
14971510
if (err%has_error_occurred()) return
14981511

14991512
! Update the Chi-squared estimate, update the damping parameter
15001513
! lambda, and, if necessary, update the matrices
15011514
call lm_update(fun, xdata, ydata, pOld, p, pTry, yOld, y, h, dX2, &
15021515
X2Old, X2, X2Try, lambda, alpha, nu, JtWdy, JtWJ, J, weights, &
1503-
niter, neval, update, step, work, mwork, controls, opt, stop)
1516+
niter, neval, update, step, work, mwork, controls, opt, stop, args)
15041517
if (stop) go to 5
15051518
JtWJc = JtWJ
15061519

@@ -1545,7 +1558,7 @@ subroutine lm_solve(fun, xdata, ydata, p, weights, maxP, minP, controls, &
15451558
!
15461559
subroutine lm_update(fun, xdata, ydata, pOld, p, pTry, yOld, y, h, dX2, &
15471560
X2old, X2, X2try, lambda, alpha, nu, JtWdy, JtWJ, J, weights, niter, &
1548-
neval, update, step, work, mwork, controls, opt, stop)
1561+
neval, update, step, work, mwork, controls, opt, stop, args)
15491562
! Arguments
15501563
procedure(regression_function), intent(in), pointer :: fun
15511564
real(real64), intent(in) :: xdata(:), ydata(:), X2try, h(:), step, &
@@ -1559,6 +1572,7 @@ subroutine lm_update(fun, xdata, ydata, pOld, p, pTry, yOld, y, h, dX2, &
15591572
class(iteration_controls), intent(in) :: controls
15601573
class(lm_solver_options), intent(in) :: opt
15611574
logical, intent(out) :: stop
1575+
class(*), intent(inout), optional :: args
15621576

15631577
! Local Variables
15641578
integer(int32) :: n
@@ -1585,7 +1599,7 @@ subroutine lm_update(fun, xdata, ydata, pOld, p, pTry, yOld, y, h, dX2, &
15851599

15861600
! Recompute the matrices
15871601
call lm_matrix(fun, xdata, ydata, pOld, yOld, dX2, J, p, weights, &
1588-
neval, update, step, JtWJ, JtWdy, X2, y, stop, work, mwork)
1602+
neval, update, step, JtWJ, JtWdy, X2, y, stop, work, mwork, args)
15891603
if (stop) return
15901604

15911605
! Decrease lambda
@@ -1605,7 +1619,7 @@ subroutine lm_update(fun, xdata, ydata, pOld, p, pTry, yOld, y, h, dX2, &
16051619
if (mod(niter, 2 * n) /= 0) then
16061620
call lm_matrix(fun, xdata, ydata, pOld, yOld, -1.0d0, J, p, &
16071621
weights, neval, update, step, JtWJ, JtWdy, dX2, y, stop, &
1608-
work, mwork)
1622+
work, mwork, args)
16091623
if (stop) return
16101624
end if
16111625

0 commit comments

Comments
 (0)