Skip to content

Commit a53878e

Browse files
committed
TST: stats.fit: fix tests
1 parent 26445e5 commit a53878e

File tree

1 file changed

+56
-10
lines changed

1 file changed

+56
-10
lines changed

scipy/stats/tests/test_fit.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ def cases_test_fit_mle():
231231
# These fail default test or hang
232232
skip_basic_fit = {'argus', 'irwinhall', 'foldnorm', 'truncpareto',
233233
'truncweibull_min', 'ksone', 'levy_stable',
234-
'studentized_range', 'kstwo', 'arcsine',
234+
'studentized_range', 'kstwo',
235+
'beta', 'nakagami', 'truncnorm', # don't meet tolerance
235236
'poisson_binom'} # vector-valued shape parameter
236237

237238
# Please keep this list in alphabetical order...
@@ -282,7 +283,7 @@ def cases_test_fit_mse():
282283
'gausshyper', 'genhyperbolic', # integration warnings
283284
'tukeylambda', # close, but doesn't meet tolerance
284285
'vonmises', # can have negative CDF; doesn't play nice
285-
'argus', # doesn't meet tolerance; tested separately
286+
'arcsine', 'argus', 'powerlaw', # don't meet tolerance
286287
'poisson_binom', # vector-valued shape parameter
287288
}
288289

@@ -375,8 +376,8 @@ class TestFit:
375376
rtol = 1e-2
376377
tols = {'atol': atol, 'rtol': rtol}
377378

378-
def opt(self, *args, **kwds):
379-
return differential_evolution(*args, rng=1, **kwds)
379+
def opt(self, *args, rng=1, **kwds):
380+
return differential_evolution(*args, rng=rng, **kwds)
380381

381382
def test_dist_iv(self):
382383
message = "`dist` must be an instance of..."
@@ -494,7 +495,7 @@ def test_guess_iv(self):
494495
with pytest.warns(RuntimeWarning, match=message):
495496
stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)
496497

497-
def basic_fit_test(self, dist_name, method):
498+
def basic_fit_test(self, dist_name, method, rng=1):
498499

499500
N = 5000
500501
dist_data = dict(distcont + distdiscrete)
@@ -530,11 +531,11 @@ def basic_fit_test(self, dist_name, method):
530531

531532
@pytest.mark.parametrize("dist_name", cases_test_fit_mle())
532533
def test_basic_fit_mle(self, dist_name):
533-
self.basic_fit_test(dist_name, "mle")
534+
self.basic_fit_test(dist_name, "mle", rng=5)
534535

535536
@pytest.mark.parametrize("dist_name", cases_test_fit_mse())
536537
def test_basic_fit_mse(self, dist_name):
537-
self.basic_fit_test(dist_name, "mse")
538+
self.basic_fit_test(dist_name, "mse", rng=2)
538539

539540
def test_arcsine(self):
540541
# Can't guarantee that all distributions will fit all data with
@@ -546,8 +547,9 @@ def test_arcsine(self):
546547
shapes = (1., 2.)
547548
data = dist.rvs(*shapes, size=N, random_state=rng)
548549
shape_bounds = {'loc': (0.1, 10), 'scale': (0.1, 10)}
549-
res = stats.fit(dist, data, shape_bounds, optimizer=self.opt)
550-
assert_nlff_less_or_close(dist, data, res.params, shapes, **self.tols)
550+
res = stats.fit(dist, data, shape_bounds, method='mse', optimizer=self.opt)
551+
assert_nlff_less_or_close(dist, data, res.params, shapes,
552+
nlff_name='_penalized_nlpsf', **self.tols)
551553

552554
@pytest.mark.parametrize("method", ('mle', 'mse'))
553555
def test_argus(self, method):
@@ -561,8 +563,24 @@ def test_argus(self, method):
561563
data = dist.rvs(*shapes, size=N, random_state=rng)
562564
shape_bounds = {'chi': (0.1, 10), 'loc': (0.1, 10), 'scale': (0.1, 10)}
563565
res = stats.fit(dist, data, shape_bounds, optimizer=self.opt, method=method)
566+
nlff_name = {'mle': 'nnlf', 'mse': '_penalized_nlpsf'}[method]
567+
assert_nlff_less_or_close(dist, data, res.params, shapes, **self.tols,
568+
nlff_name=nlff_name)
564569

565-
assert_nlff_less_or_close(dist, data, res.params, shapes, **self.tols)
570+
def test_beta(self):
571+
# Can't guarantee that all distributions will fit all data with
572+
# arbitrary bounds. This distribution just happens to fail above.
573+
# Try something slightly different.
574+
N = 1000
575+
rng = np.random.default_rng(self.seed)
576+
dist = stats.beta
577+
shapes = (2.3098496451481823, 0.62687954300963677, 1., 2.)
578+
data = dist.rvs(*shapes, size=N, random_state=rng)
579+
shape_bounds = {'a': (0.1, 10), 'b':(0.1, 10),
580+
'loc': (0.1, 10), 'scale': (0.1, 10)}
581+
res = stats.fit(dist, data, shape_bounds, method='mle', optimizer=self.opt)
582+
assert_nlff_less_or_close(dist, data, res.params, shapes,
583+
nlff_name='nnlf', **self.tols)
566584

567585
def test_foldnorm(self):
568586
# Can't guarantee that all distributions will fit all data with
@@ -578,6 +596,34 @@ def test_foldnorm(self):
578596

579597
assert_nlff_less_or_close(dist, data, res.params, shapes, **self.tols)
580598

599+
def test_nakagami(self):
600+
# Can't guarantee that all distributions will fit all data with
601+
# arbitrary bounds. This distribution just happens to fail above.
602+
# Try something slightly different.
603+
N = 1000
604+
rng = np.random.default_rng(self.seed)
605+
dist = stats.nakagami
606+
shapes = (4.9673794866666237, 1., 2.)
607+
data = dist.rvs(*shapes, size=N, random_state=rng)
608+
shape_bounds = {'nu':(0.1, 10), 'loc': (0.1, 10), 'scale': (0.1, 10)}
609+
res = stats.fit(dist, data, shape_bounds, method='mle', optimizer=self.opt)
610+
assert_nlff_less_or_close(dist, data, res.params, shapes,
611+
nlff_name='nnlf', **self.tols)
612+
613+
def test_powerlaw(self):
614+
# Can't guarantee that all distributions will fit all data with
615+
# arbitrary bounds. This distribution just happens to fail above.
616+
# Try something slightly different.
617+
N = 1000
618+
rng = np.random.default_rng(self.seed)
619+
dist = stats.powerlaw
620+
shapes = (1.6591133289905851, 1., 2.)
621+
data = dist.rvs(*shapes, size=N, random_state=rng)
622+
shape_bounds = {'a': (0.1, 10), 'loc': (0.1, 10), 'scale': (0.1, 10)}
623+
res = stats.fit(dist, data, shape_bounds, method='mse', optimizer=self.opt)
624+
assert_nlff_less_or_close(dist, data, res.params, shapes,
625+
nlff_name='_penalized_nlpsf', **self.tols)
626+
581627
def test_truncpareto(self):
582628
# Can't guarantee that all distributions will fit all data with
583629
# arbitrary bounds. This distribution just happens to fail above.

0 commit comments

Comments
 (0)