Skip to content

Commit 716dbce

Browse files
committed
fix all flake8 issues and run pre-commit filters
1 parent f5063fb commit 716dbce

33 files changed

+1130
-852
lines changed

adaptive/_version.py

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from setuptools.command.sdist import sdist as sdist_orig
1010

11-
Version = namedtuple('Version', ('release', 'dev', 'labels'))
11+
Version = namedtuple("Version", ("release", "dev", "labels"))
1212

1313
# No public API
1414
__all__ = []
@@ -17,12 +17,12 @@
1717
package_name = os.path.basename(package_root)
1818
distr_root = os.path.dirname(package_root)
1919

20-
STATIC_VERSION_FILE = '_static_version.py'
20+
STATIC_VERSION_FILE = "_static_version.py"
2121

2222

2323
def get_version(version_file=STATIC_VERSION_FILE):
2424
version_info = get_static_version_info(version_file)
25-
version = version_info['version']
25+
version = version_info["version"]
2626
if version == "__use_git__":
2727
version = get_version_from_git()
2828
if not version:
@@ -36,43 +36,45 @@ def get_version(version_file=STATIC_VERSION_FILE):
3636

3737
def get_static_version_info(version_file=STATIC_VERSION_FILE):
3838
version_info = {}
39-
with open(os.path.join(package_root, version_file), 'rb') as f:
39+
with open(os.path.join(package_root, version_file), "rb") as f:
4040
exec(f.read(), {}, version_info)
4141
return version_info
4242

4343

4444
def version_is_from_git(version_file=STATIC_VERSION_FILE):
45-
return get_static_version_info(version_file)['version'] == '__use_git__'
45+
return get_static_version_info(version_file)["version"] == "__use_git__"
4646

4747

4848
def pep440_format(version_info):
4949
release, dev, labels = version_info
5050

5151
version_parts = [release]
5252
if dev:
53-
if release.endswith('-dev') or release.endswith('.dev'):
53+
if release.endswith("-dev") or release.endswith(".dev"):
5454
version_parts.append(dev)
5555
else: # prefer PEP440 over strict adhesion to semver
56-
version_parts.append(f'.dev{dev}')
56+
version_parts.append(f".dev{dev}")
5757

5858
if labels:
59-
version_parts.append('+')
59+
version_parts.append("+")
6060
version_parts.append(".".join(labels))
6161

6262
return "".join(version_parts)
6363

6464

6565
def get_version_from_git():
6666
try:
67-
p = subprocess.Popen(['git', 'rev-parse', '--show-toplevel'],
68-
cwd=distr_root,
69-
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
67+
p = subprocess.Popen(
68+
["git", "rev-parse", "--show-toplevel"],
69+
cwd=distr_root,
70+
stdout=subprocess.PIPE,
71+
stderr=subprocess.PIPE,
72+
)
7073
except OSError:
7174
return
7275
if p.wait() != 0:
7376
return
74-
if not os.path.samefile(p.communicate()[0].decode().rstrip('\n'),
75-
distr_root):
77+
if not os.path.samefile(p.communicate()[0].decode().rstrip("\n"), distr_root):
7678
# The top-level directory of the current Git repository is not the same
7779
# as the root directory of the distribution: do not extract the
7880
# version from Git.
@@ -81,12 +83,14 @@ def get_version_from_git():
8183
# git describe --first-parent does not take into account tags from branches
8284
# that were merged-in. The '--long' flag gets us the 'dev' version and
8385
# git hash, '--always' returns the git hash even if there are no tags.
84-
for opts in [['--first-parent'], []]:
86+
for opts in [["--first-parent"], []]:
8587
try:
8688
p = subprocess.Popen(
87-
['git', 'describe', '--long', '--always'] + opts,
89+
["git", "describe", "--long", "--always"] + opts,
8890
cwd=distr_root,
89-
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
91+
stdout=subprocess.PIPE,
92+
stderr=subprocess.PIPE,
93+
)
9094
except OSError:
9195
return
9296
if p.wait() == 0:
@@ -97,17 +101,17 @@ def get_version_from_git():
97101
description = (
98102
p.communicate()[0]
99103
.decode()
100-
.strip('v') # Tags can have a leading 'v', but the version should not
101-
.rstrip('\n')
102-
.rsplit('-', 2) # Split the latest tag, commits since tag, and hash
104+
.strip("v") # Tags can have a leading 'v', but the version should not
105+
.rstrip("\n")
106+
.rsplit("-", 2) # Split the latest tag, commits since tag, and hash
103107
)
104108

105109
try:
106110
release, dev, git = description
107111
except ValueError: # No tags, only the git hash
108112
# prepend 'g' to match with format returned by 'git describe'
109-
git = 'g{}'.format(*description)
110-
release = 'unknown'
113+
git = "g{}".format(*description)
114+
release = "unknown"
111115
dev = None
112116

113117
labels = []
@@ -117,12 +121,12 @@ def get_version_from_git():
117121
labels.append(git)
118122

119123
try:
120-
p = subprocess.Popen(['git', 'diff', '--quiet'], cwd=distr_root)
124+
p = subprocess.Popen(["git", "diff", "--quiet"], cwd=distr_root)
121125
except OSError:
122-
labels.append('confused') # This should never happen.
126+
labels.append("confused") # This should never happen.
123127
else:
124128
if p.wait() == 1:
125-
labels.append('dirty')
129+
labels.append("dirty")
126130

127131
return Version(release, dev, labels)
128132

@@ -134,25 +138,25 @@ def get_version_from_git():
134138
# if it is not tagged.
135139
def get_version_from_git_archive(version_info):
136140
try:
137-
refnames = version_info['refnames']
138-
git_hash = version_info['git_hash']
141+
refnames = version_info["refnames"]
142+
git_hash = version_info["git_hash"]
139143
except KeyError:
140144
# These fields are not present if we are running from an sdist.
141145
# Execution should never reach here, though
142146
return None
143147

144-
if git_hash.startswith('$Format') or refnames.startswith('$Format'):
148+
if git_hash.startswith("$Format") or refnames.startswith("$Format"):
145149
# variables not expanded during 'git archive'
146150
return None
147151

148-
VTAG = 'tag: v'
152+
VTAG = "tag: v"
149153
refs = {r.strip() for r in refnames.split(",")}
150-
version_tags = {r[len(VTAG):] for r in refs if r.startswith(VTAG)}
154+
version_tags = {r[len(VTAG) :] for r in refs if r.startswith(VTAG)}
151155
if version_tags:
152156
release, *_ = sorted(version_tags) # prefer e.g. "2.0" over "2.0rc1"
153157
return Version(release, dev=None, labels=None)
154158
else:
155-
return Version('unknown', dev=None, labels=[f'g{git_hash}'])
159+
return Version("unknown", dev=None, labels=[f"g{git_hash}"])
156160

157161

158162
__version__ = get_version()
@@ -162,30 +166,31 @@ def get_version_from_git_archive(version_info):
162166
# which can be used from setup.py. The 'package_name' and
163167
# '__version__' module globals are used (but not modified).
164168

169+
165170
def _write_version(fname):
166171
# This could be a hard link, so try to delete it first. Is there any way
167172
# to do this atomically together with opening?
168173
try:
169174
os.remove(fname)
170175
except OSError:
171176
pass
172-
with open(fname, 'w') as f:
173-
f.write("# This file has been created by setup.py.\n"
174-
"version = '{}'\n".format(__version__))
177+
with open(fname, "w") as f:
178+
f.write(
179+
"# This file has been created by setup.py.\n"
180+
"version = '{}'\n".format(__version__)
181+
)
175182

176183

177184
class _build_py(build_py_orig):
178185
def run(self):
179186
super().run()
180-
_write_version(os.path.join(self.build_lib, package_name,
181-
STATIC_VERSION_FILE))
187+
_write_version(os.path.join(self.build_lib, package_name, STATIC_VERSION_FILE))
182188

183189

184190
class _sdist(sdist_orig):
185191
def make_release_tree(self, base_dir, files):
186192
super().make_release_tree(base_dir, files)
187-
_write_version(os.path.join(base_dir, package_name,
188-
STATIC_VERSION_FILE))
193+
_write_version(os.path.join(base_dir, package_name, STATIC_VERSION_FILE))
189194

190195

191196
cmdclass = dict(sdist=_sdist, build_py=_build_py)

adaptive/learner/average_learner.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class AverageLearner(BaseLearner):
3434

3535
def __init__(self, function, atol=None, rtol=None):
3636
if atol is None and rtol is None:
37-
raise Exception('At least one of `atol` and `rtol` should be set.')
37+
raise Exception("At least one of `atol` and `rtol` should be set.")
3838
if atol is None:
3939
atol = np.inf
4040
if rtol is None:
@@ -58,9 +58,11 @@ def ask(self, n, tell_pending=True):
5858

5959
if any(p in self.data or p in self.pending_points for p in points):
6060
# This means some of the points `< self.n_requested` do not exist.
61-
points = list(set(range(self.n_requested + n))
62-
- set(self.data)
63-
- set(self.pending_points))[:n]
61+
points = list(
62+
set(range(self.n_requested + n))
63+
- set(self.data)
64+
- set(self.pending_points)
65+
)[:n]
6466

6567
loss_improvements = [self._loss_improvement(n) / n] * n
6668
if tell_pending:
@@ -76,7 +78,7 @@ def tell(self, n, value):
7678
self.data[n] = value
7779
self.pending_points.discard(n)
7880
self.sum_f += value
79-
self.sum_f_sq += value**2
81+
self.sum_f_sq += value ** 2
8082
self.npoints += 1
8183

8284
def tell_pending(self, n):
@@ -94,7 +96,7 @@ def std(self):
9496
n = self.npoints
9597
if n < 2:
9698
return np.inf
97-
numerator = self.sum_f_sq - n * self.mean**2
99+
numerator = self.sum_f_sq - n * self.mean ** 2
98100
if numerator < 0:
99101
# in this case the numerator ~ -1e-15
100102
return 0
@@ -109,8 +111,9 @@ def loss(self, real=True, *, n=None):
109111
if n < 2:
110112
return np.inf
111113
standard_error = self.std / sqrt(n)
112-
return max(standard_error / self.atol,
113-
standard_error / abs(self.mean) / self.rtol)
114+
return max(
115+
standard_error / self.atol, standard_error / abs(self.mean) / self.rtol
116+
)
114117

115118
def _loss_improvement(self, n):
116119
loss = self.loss()

adaptive/learner/balancing_learner.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# -*- coding: utf-8 -*-
22

3-
import os.path
43
from collections import defaultdict
54
from collections.abc import Iterable
65
from contextlib import suppress
@@ -69,7 +68,7 @@ class BalancingLearner(BaseLearner):
6968
behave in an undefined way. Change the `strategy` in that case.
7069
"""
7170

72-
def __init__(self, learners, *, cdims=None, strategy='loss_improvements'):
71+
def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
7372
self.learners = learners
7473

7574
# Naively we would make 'function' a method, but this causes problems
@@ -83,8 +82,9 @@ def __init__(self, learners, *, cdims=None, strategy='loss_improvements'):
8382
self._cdims_default = cdims
8483

8584
if len({learner.__class__ for learner in self.learners}) > 1:
86-
raise TypeError('A BalacingLearner can handle only one type'
87-
' of learners.')
85+
raise TypeError(
86+
"A BalacingLearner can handle only one type" " of learners."
87+
)
8888

8989
self.strategy = strategy
9090

@@ -101,16 +101,17 @@ def strategy(self):
101101
@strategy.setter
102102
def strategy(self, strategy):
103103
self._strategy = strategy
104-
if strategy == 'loss_improvements':
104+
if strategy == "loss_improvements":
105105
self._ask_and_tell = self._ask_and_tell_based_on_loss_improvements
106-
elif strategy == 'loss':
106+
elif strategy == "loss":
107107
self._ask_and_tell = self._ask_and_tell_based_on_loss
108-
elif strategy == 'npoints':
108+
elif strategy == "npoints":
109109
self._ask_and_tell = self._ask_and_tell_based_on_npoints
110110
else:
111111
raise ValueError(
112112
'Only strategy="loss_improvements", strategy="loss", or'
113-
' strategy="npoints" is implemented.')
113+
' strategy="npoints" is implemented.'
114+
)
114115

115116
def _ask_and_tell_based_on_loss_improvements(self, n):
116117
selected = [] # tuples ((learner_index, point), loss_improvement)
@@ -120,17 +121,14 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
120121
for index, learner in enumerate(self.learners):
121122
# Take the points from the cache
122123
if index not in self._ask_cache:
123-
self._ask_cache[index] = learner.ask(
124-
n=1, tell_pending=False)
124+
self._ask_cache[index] = learner.ask(n=1, tell_pending=False)
125125
points, loss_improvements = self._ask_cache[index]
126126
to_select.append(
127-
((index, points[0]),
128-
(loss_improvements[0], -total_points[index]))
127+
((index, points[0]), (loss_improvements[0], -total_points[index]))
129128
)
130129

131130
# Choose the optimal improvement.
132-
(index, point), (loss_improvement, _) = max(
133-
to_select, key=itemgetter(1))
131+
(index, point), (loss_improvement, _) = max(to_select, key=itemgetter(1))
134132
total_points[index] += 1
135133
selected.append(((index, point), loss_improvement))
136134
self.tell_pending((index, point))
@@ -139,13 +137,12 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
139137
return points, loss_improvements
140138

141139
def _ask_and_tell_based_on_loss(self, n):
142-
selected = [] # tuples ((learner_index, point), loss_improvement)
140+
selected = [] # tuples ((learner_index, point), loss_improvement)
143141
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
144142
for _ in range(n):
145143
losses = self._losses(real=False)
146144
index, _ = max(
147-
enumerate(zip(losses, (-n for n in total_points))),
148-
key=itemgetter(1)
145+
enumerate(zip(losses, (-n for n in total_points))), key=itemgetter(1)
149146
)
150147
total_points[index] += 1
151148

@@ -257,7 +254,7 @@ def plot(self, cdims=None, plotter=None, dynamic=True):
257254
cdims = cdims or self._cdims_default
258255

259256
if cdims is None:
260-
cdims = [{'i': i} for i in range(len(self.learners))]
257+
cdims = [{"i": i} for i in range(len(self.learners))]
261258
elif not isinstance(cdims[0], dict):
262259
# Normalize the format
263260
keys, values_list = cdims

adaptive/learner/base_learner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ def uses_nth_neighbors(n):
5353
...
5454
... return loss
5555
"""
56+
5657
def _wrapped(loss_per_interval):
5758
loss_per_interval.nth_neighbors = n
5859
return loss_per_interval
60+
5961
return _wrapped
6062

6163

0 commit comments

Comments
 (0)