Skip to content

Commit 78edec0

Browse files
authored
ensure that correct initial params are used when re-fitting a ModelResult (#961)
ensure that correct initial params are used when re-fitting a ModeRresult
1 parent b72cfb2 commit 78edec0

File tree

2 files changed

+61
-5
lines changed

2 files changed

+61
-5
lines changed

lmfit/model.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,10 +1255,14 @@ def __init__(self, left, right, op, **kws):
12551255
if 'nan_policy' not in kws:
12561256
kws['nan_policy'] = self.left.nan_policy
12571257

1258+
# CompositeModel cannot have a prefix.
1259+
if 'prefix' in kws:
1260+
warnings.warn("CompositeModel ignores `prefix` argument")
1261+
kws['prefix'] = ''
1262+
12581263
def _tmp(self, *args, **kws):
12591264
pass
12601265
Model.__init__(self, _tmp, **kws)
1261-
12621266
for side in (left, right):
12631267
prefix = side.prefix
12641268
for basename, hint in side.param_hints.items():
@@ -1548,7 +1552,10 @@ def fit(self, data=None, params=None, weights=None, method=None,
15481552
if data is not None:
15491553
self.data = data
15501554
if params is not None:
1551-
self.init_params = params
1555+
self.init_params = deepcopy(params)
1556+
else:
1557+
self.init_params = deepcopy(self.params)
1558+
15521559
if weights is not None:
15531560
self.weights = weights
15541561
if method is not None:
@@ -1559,8 +1566,8 @@ def fit(self, data=None, params=None, weights=None, method=None,
15591566
self.ci_out = None
15601567
self.userargs = (self.data, self.weights)
15611568
self.userkws.update(kwargs)
1562-
self.init_fit = self.model.eval(params=self.params, **self.userkws)
1563-
_ret = self.minimize(method=self.method)
1569+
self.init_fit = self.model.eval(params=self.init_params, **self.userkws)
1570+
_ret = self.minimize(method=self.method, params=self.init_params)
15641571
self.model.post_fit(_ret)
15651572
_ret.params.create_uvars(covar=_ret.covar)
15661573

tests/test_model.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from lmfit import Model, Parameters, models
1414
from lmfit.lineshapes import gaussian, lorentzian, step, voigt
1515
from lmfit.model import get_reducer, propagate_err
16-
from lmfit.models import GaussianModel, PseudoVoigtModel
16+
from lmfit.models import GaussianModel, PseudoVoigtModel, QuadraticModel
1717

1818

1919
@pytest.fixture()
@@ -1648,3 +1648,52 @@ def test_custom_variadic_model():
16481648
assert result.nfev > 7
16491649
assert_allclose(result.values['c0'], 5.0, 0.02, 0.02, '', True)
16501650
assert_allclose(result.values['c1'], 3.3, 0.02, 0.02, '', True)
1651+
1652+
1653+
def test_model_refitting():
1654+
"""Github #960"""
1655+
np.random.seed(0)
1656+
x = np.linspace(0, 100, 5001)
1657+
y = gaussian(x, amplitude=90, center=60, sigma=4) + 30 + 0.3*x - 0.0030*x*x
1658+
y += np.random.normal(size=5001, scale=0.5)
1659+
1660+
model = GaussianModel(prefix='peak_') + QuadraticModel(prefix='bkg_')
1661+
1662+
params = model.make_params(bkg_a=0, bkg_b=0, bkg_c=20, peak_amplitude=200,
1663+
peak_center=55, peak_sigma=10)
1664+
1665+
result = model.fit(y, params, x=x, method='powell')
1666+
assert result.chisqr > 12000.0
1667+
assert result.nfev > 500
1668+
assert result.params['peak_amplitude'].value > 500
1669+
assert result.params['peak_amplitude'].value < 5000
1670+
assert result.params['peak_sigma'].value > 10
1671+
assert result.params['peak_sigma'].value < 100
1672+
1673+
# now re-fit with LM
1674+
result.fit(y, x=x, method='leastsq')
1675+
1676+
assert result.nfev > 25
1677+
assert result.nfev < 200
1678+
assert result.chisqr < 2000.0
1679+
1680+
assert result.params['peak_amplitude'].value > 85
1681+
assert result.params['peak_amplitude'].value < 95
1682+
assert result.params['peak_sigma'].value > 3
1683+
assert result.params['peak_sigma'].value < 5
1684+
1685+
# and assert that the initial value are from the Powell result
1686+
assert result.init_values['peak_amplitude'] > 1500
1687+
assert result.init_values['peak_sigma'] > 25
1688+
1689+
params = model.make_params(bkg_a=0, bkg_b=-.02, bkg_c=26, peak_amplitude=20,
1690+
peak_center=62, peak_sigma=3)
1691+
1692+
# now re-fit with LM and these new params
1693+
result.fit(y, params, x=x, method='leastsq')
1694+
1695+
# and assert that the initial value are from the Powell result
1696+
assert result.init_values['peak_amplitude'] > 19
1697+
assert result.init_values['peak_amplitude'] < 21
1698+
assert result.init_values['peak_sigma'] > 2
1699+
assert result.init_values['peak_sigma'] < 4

0 commit comments

Comments
 (0)