Skip to content

Commit afb3591

Browse files
committed
merge the style changes from stable-0.8
2 parents ebb4879 + edc36d9 commit afb3591

40 files changed

+1245
-888
lines changed

.pre-commit-config.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
repos:
2+
- repo: https://github.com/ambv/black
3+
rev: 19.3b0
4+
hooks:
5+
- id: black
6+
language_version: python3.7
7+
- repo: https://github.com/pre-commit/pre-commit-hooks
8+
rev: v2.1.0
9+
hooks:
10+
- id: end-of-file-fixer
11+
- id: trailing-whitespace
12+
- repo: https://gitlab.com/pycqa/flake8
13+
rev: 3.7.4
14+
hooks:
15+
- id: flake8
16+
args: ['--max-line-length=500', '--ignore=E203,E266,E501,W503', '--max-complexity=18', '--select=B,C,E,F,W,T4,B9']

README.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,14 @@ please setup the git filter by executing
113113
114114
in the repository.
115115

116+
We implement several other checks in order to maintain a consistent code style. We do this using [pre-commit`](https://pre-commit.com), execute
117+
118+
.. code:: bash
119+
120+
pre-commit install
121+
122+
in the repository.
123+
116124
Credits
117125
-------
118126

adaptive/__init__.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,52 @@
44

55
from adaptive import learner, runner, utils
66
from adaptive._version import __version__
7-
from adaptive.learner import (AverageLearner, BalancingLearner, BaseLearner,
8-
DataSaver, IntegratorLearner, Learner1D,
9-
Learner2D, LearnerND, make_datasaver)
10-
from adaptive.notebook_integration import (active_plotting_tasks, live_plot,
11-
notebook_extension)
7+
from adaptive.learner import (
8+
AverageLearner,
9+
BalancingLearner,
10+
BaseLearner,
11+
DataSaver,
12+
IntegratorLearner,
13+
Learner1D,
14+
Learner2D,
15+
LearnerND,
16+
make_datasaver,
17+
)
18+
from adaptive.notebook_integration import (
19+
active_plotting_tasks,
20+
live_plot,
21+
notebook_extension,
22+
)
1223
from adaptive.runner import AsyncRunner, BlockingRunner, Runner
1324

25+
__all__ = [
26+
"learner",
27+
"runner",
28+
"utils",
29+
"__version__",
30+
"AverageLearner",
31+
"BalancingLearner",
32+
"BaseLearner",
33+
"DataSaver",
34+
"IntegratorLearner",
35+
"Learner1D",
36+
"Learner2D",
37+
"LearnerND",
38+
"make_datasaver",
39+
"active_plotting_tasks",
40+
"live_plot",
41+
"notebook_extension",
42+
"AsyncRunner",
43+
"BlockingRunner",
44+
"Runner",
45+
]
46+
1447
with suppress(ImportError):
1548
# Only available if 'scikit-optimize' is installed
16-
from adaptive.learner import SKOptLearner
49+
from adaptive.learner import SKOptLearner # noqa: F401
1750

51+
__all__.append("SKOptLearner")
1852

19-
del _version
20-
del notebook_integration # to avoid confusion with `notebook_extension`
53+
# to avoid confusion with `notebook_extension` and `__version__`
54+
del _version # noqa: F821
55+
del notebook_integration # noqa: F821

adaptive/_version.py

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from setuptools.command.sdist import sdist as sdist_orig
1010

11-
Version = namedtuple('Version', ('release', 'dev', 'labels'))
11+
Version = namedtuple("Version", ("release", "dev", "labels"))
1212

1313
# No public API
1414
__all__ = []
@@ -17,12 +17,12 @@
1717
package_name = os.path.basename(package_root)
1818
distr_root = os.path.dirname(package_root)
1919

20-
STATIC_VERSION_FILE = '_static_version.py'
20+
STATIC_VERSION_FILE = "_static_version.py"
2121

2222

2323
def get_version(version_file=STATIC_VERSION_FILE):
2424
version_info = get_static_version_info(version_file)
25-
version = version_info['version']
25+
version = version_info["version"]
2626
if version == "__use_git__":
2727
version = get_version_from_git()
2828
if not version:
@@ -36,43 +36,45 @@ def get_version(version_file=STATIC_VERSION_FILE):
3636

3737
def get_static_version_info(version_file=STATIC_VERSION_FILE):
3838
version_info = {}
39-
with open(os.path.join(package_root, version_file), 'rb') as f:
39+
with open(os.path.join(package_root, version_file), "rb") as f:
4040
exec(f.read(), {}, version_info)
4141
return version_info
4242

4343

4444
def version_is_from_git(version_file=STATIC_VERSION_FILE):
45-
return get_static_version_info(version_file)['version'] == '__use_git__'
45+
return get_static_version_info(version_file)["version"] == "__use_git__"
4646

4747

4848
def pep440_format(version_info):
4949
release, dev, labels = version_info
5050

5151
version_parts = [release]
5252
if dev:
53-
if release.endswith('-dev') or release.endswith('.dev'):
53+
if release.endswith("-dev") or release.endswith(".dev"):
5454
version_parts.append(dev)
5555
else: # prefer PEP440 over strict adhesion to semver
56-
version_parts.append(f'.dev{dev}')
56+
version_parts.append(f".dev{dev}")
5757

5858
if labels:
59-
version_parts.append('+')
59+
version_parts.append("+")
6060
version_parts.append(".".join(labels))
6161

6262
return "".join(version_parts)
6363

6464

6565
def get_version_from_git():
6666
try:
67-
p = subprocess.Popen(['git', 'rev-parse', '--show-toplevel'],
68-
cwd=distr_root,
69-
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
67+
p = subprocess.Popen(
68+
["git", "rev-parse", "--show-toplevel"],
69+
cwd=distr_root,
70+
stdout=subprocess.PIPE,
71+
stderr=subprocess.PIPE,
72+
)
7073
except OSError:
7174
return
7275
if p.wait() != 0:
7376
return
74-
if not os.path.samefile(p.communicate()[0].decode().rstrip('\n'),
75-
distr_root):
77+
if not os.path.samefile(p.communicate()[0].decode().rstrip("\n"), distr_root):
7678
# The top-level directory of the current Git repository is not the same
7779
# as the root directory of the distribution: do not extract the
7880
# version from Git.
@@ -81,12 +83,14 @@ def get_version_from_git():
8183
# git describe --first-parent does not take into account tags from branches
8284
# that were merged-in. The '--long' flag gets us the 'dev' version and
8385
# git hash, '--always' returns the git hash even if there are no tags.
84-
for opts in [['--first-parent'], []]:
86+
for opts in [["--first-parent"], []]:
8587
try:
8688
p = subprocess.Popen(
87-
['git', 'describe', '--long', '--always'] + opts,
89+
["git", "describe", "--long", "--always"] + opts,
8890
cwd=distr_root,
89-
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
91+
stdout=subprocess.PIPE,
92+
stderr=subprocess.PIPE,
93+
)
9094
except OSError:
9195
return
9296
if p.wait() == 0:
@@ -97,17 +101,17 @@ def get_version_from_git():
97101
description = (
98102
p.communicate()[0]
99103
.decode()
100-
.strip('v') # Tags can have a leading 'v', but the version should not
101-
.rstrip('\n')
102-
.rsplit('-', 2) # Split the latest tag, commits since tag, and hash
104+
.strip("v") # Tags can have a leading 'v', but the version should not
105+
.rstrip("\n")
106+
.rsplit("-", 2) # Split the latest tag, commits since tag, and hash
103107
)
104108

105109
try:
106110
release, dev, git = description
107111
except ValueError: # No tags, only the git hash
108112
# prepend 'g' to match with format returned by 'git describe'
109-
git = 'g{}'.format(*description)
110-
release = 'unknown'
113+
git = "g{}".format(*description)
114+
release = "unknown"
111115
dev = None
112116

113117
labels = []
@@ -117,12 +121,12 @@ def get_version_from_git():
117121
labels.append(git)
118122

119123
try:
120-
p = subprocess.Popen(['git', 'diff', '--quiet'], cwd=distr_root)
124+
p = subprocess.Popen(["git", "diff", "--quiet"], cwd=distr_root)
121125
except OSError:
122-
labels.append('confused') # This should never happen.
126+
labels.append("confused") # This should never happen.
123127
else:
124128
if p.wait() == 1:
125-
labels.append('dirty')
129+
labels.append("dirty")
126130

127131
return Version(release, dev, labels)
128132

@@ -134,25 +138,25 @@ def get_version_from_git():
134138
# if it is not tagged.
135139
def get_version_from_git_archive(version_info):
136140
try:
137-
refnames = version_info['refnames']
138-
git_hash = version_info['git_hash']
141+
refnames = version_info["refnames"]
142+
git_hash = version_info["git_hash"]
139143
except KeyError:
140144
# These fields are not present if we are running from an sdist.
141145
# Execution should never reach here, though
142146
return None
143147

144-
if git_hash.startswith('$Format') or refnames.startswith('$Format'):
148+
if git_hash.startswith("$Format") or refnames.startswith("$Format"):
145149
# variables not expanded during 'git archive'
146150
return None
147151

148-
VTAG = 'tag: v'
152+
VTAG = "tag: v"
149153
refs = {r.strip() for r in refnames.split(",")}
150-
version_tags = {r[len(VTAG):] for r in refs if r.startswith(VTAG)}
154+
version_tags = {r[len(VTAG) :] for r in refs if r.startswith(VTAG)}
151155
if version_tags:
152156
release, *_ = sorted(version_tags) # prefer e.g. "2.0" over "2.0rc1"
153157
return Version(release, dev=None, labels=None)
154158
else:
155-
return Version('unknown', dev=None, labels=[f'g{git_hash}'])
159+
return Version("unknown", dev=None, labels=[f"g{git_hash}"])
156160

157161

158162
__version__ = get_version()
@@ -162,30 +166,31 @@ def get_version_from_git_archive(version_info):
162166
# which can be used from setup.py. The 'package_name' and
163167
# '__version__' module globals are used (but not modified).
164168

169+
165170
def _write_version(fname):
166171
# This could be a hard link, so try to delete it first. Is there any way
167172
# to do this atomically together with opening?
168173
try:
169174
os.remove(fname)
170175
except OSError:
171176
pass
172-
with open(fname, 'w') as f:
173-
f.write("# This file has been created by setup.py.\n"
174-
"version = '{}'\n".format(__version__))
177+
with open(fname, "w") as f:
178+
f.write(
179+
"# This file has been created by setup.py.\n"
180+
"version = '{}'\n".format(__version__)
181+
)
175182

176183

177184
class _build_py(build_py_orig):
178185
def run(self):
179186
super().run()
180-
_write_version(os.path.join(self.build_lib, package_name,
181-
STATIC_VERSION_FILE))
187+
_write_version(os.path.join(self.build_lib, package_name, STATIC_VERSION_FILE))
182188

183189

184190
class _sdist(sdist_orig):
185191
def make_release_tree(self, base_dir, files):
186192
super().make_release_tree(base_dir, files)
187-
_write_version(os.path.join(base_dir, package_name,
188-
STATIC_VERSION_FILE))
193+
_write_version(os.path.join(base_dir, package_name, STATIC_VERSION_FILE))
189194

190195

191196
cmdclass = dict(sdist=_sdist, build_py=_build_py)

adaptive/learner/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,20 @@
1111
from adaptive.learner.learner2D import Learner2D
1212
from adaptive.learner.learnerND import LearnerND
1313

14+
__all__ = [
15+
"AverageLearner",
16+
"BalancingLearner",
17+
"BaseLearner",
18+
"DataSaver",
19+
"make_datasaver",
20+
"IntegratorLearner",
21+
"Learner1D",
22+
"Learner2D",
23+
"LearnerND",
24+
]
25+
1426
with suppress(ImportError):
1527
# Only available if 'scikit-optimize' is installed
16-
from adaptive.learner.skopt_learner import SKOptLearner
28+
from adaptive.learner.skopt_learner import SKOptLearner # noqa: F401
29+
30+
__all__.append("SKOptLearner")

adaptive/learner/average_learner.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class AverageLearner(BaseLearner):
3434

3535
def __init__(self, function, atol=None, rtol=None):
3636
if atol is None and rtol is None:
37-
raise Exception('At least one of `atol` and `rtol` should be set.')
37+
raise Exception("At least one of `atol` and `rtol` should be set.")
3838
if atol is None:
3939
atol = np.inf
4040
if rtol is None:
@@ -58,9 +58,11 @@ def ask(self, n, tell_pending=True):
5858

5959
if any(p in self.data or p in self.pending_points for p in points):
6060
# This means some of the points `< self.n_requested` do not exist.
61-
points = list(set(range(self.n_requested + n))
62-
- set(self.data)
63-
- set(self.pending_points))[:n]
61+
points = list(
62+
set(range(self.n_requested + n))
63+
- set(self.data)
64+
- set(self.pending_points)
65+
)[:n]
6466

6567
loss_improvements = [self._loss_improvement(n) / n] * n
6668
if tell_pending:
@@ -76,7 +78,7 @@ def tell(self, n, value):
7678
self.data[n] = value
7779
self.pending_points.discard(n)
7880
self.sum_f += value
79-
self.sum_f_sq += value**2
81+
self.sum_f_sq += value ** 2
8082
self.npoints += 1
8183

8284
def tell_pending(self, n):
@@ -94,7 +96,7 @@ def std(self):
9496
n = self.npoints
9597
if n < 2:
9698
return np.inf
97-
numerator = self.sum_f_sq - n * self.mean**2
99+
numerator = self.sum_f_sq - n * self.mean ** 2
98100
if numerator < 0:
99101
# in this case the numerator ~ -1e-15
100102
return 0
@@ -109,8 +111,9 @@ def loss(self, real=True, *, n=None):
109111
if n < 2:
110112
return np.inf
111113
standard_error = self.std / sqrt(n)
112-
return max(standard_error / self.atol,
113-
standard_error / abs(self.mean) / self.rtol)
114+
return max(
115+
standard_error / self.atol, standard_error / abs(self.mean) / self.rtol
116+
)
114117

115118
def _loss_improvement(self, n):
116119
loss = self.loss()

0 commit comments

Comments
 (0)