Skip to content

Commit c8ba264

Browse files
bowiechenfacebook-github-bot
authored andcommitted
Update pyfmt component on FBS:master (facebookresearch#824)
Summary: X-link: meta-pytorch/tritonbench#661 X-link: facebook/Ax#4572 X-link: facebookexternal/aepsych_prerelease#42 X-link: facebook/dotslash#89 X-link: meta-pytorch/torchx#1165 X-link: meta-pytorch/botorch#3088 Differential Revision: D87671961
1 parent 8f68733 commit c8ba264

File tree

19 files changed

+275
-190
lines changed

19 files changed

+275
-190
lines changed

aepsych/acquisition/lookahead.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,9 @@ def __init__(
272272
"""
273273
super().__init__(model=model, target=target, lookahead_type=lookahead_type)
274274
self.posterior_transform = posterior_transform
275-
assert (
276-
Xq is not None or query_set_size is not None
277-
), "Must pass either query set size or a query set!"
275+
assert Xq is not None or query_set_size is not None, (
276+
"Must pass either query set size or a query set!"
277+
)
278278
if Xq is not None and query_set_size is not None:
279279
assert Xq.shape[0] == query_set_size, (
280280
"If passing both Xq and query_set_size,"
@@ -360,9 +360,9 @@ def __init__(
360360
query_set_size (int, optional): Number of points in the query set.
361361
Xq (torch.Tensor, optional): (m x d) global reference set.
362362
"""
363-
assert (
364-
lookahead_type == "levelset"
365-
), f"ApproxGlobalSUR only supports lookahead on level set, got {lookahead_type}!"
363+
assert lookahead_type == "levelset", (
364+
f"ApproxGlobalSUR only supports lookahead on level set, got {lookahead_type}!"
365+
)
366366
super().__init__(
367367
lb=lb,
368368
ub=ub,

aepsych/benchmark/pathos_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def run_benchmarks_with_checkpoints(
264264
temp_results["rep"] = temp_results["rep"] + n_reps_per_chunk * chunk
265265
temp_results.to_csv(intermediate_fname)
266266
print(
267-
f"Collate done in {time.time()-collate_start} seconds, {len(bench.futures)}/{bench.num_benchmarks} left"
267+
f"Collate done in {time.time() - collate_start} seconds, {len(bench.futures)}/{bench.num_benchmarks} left"
268268
)
269269

270270
print(f"{benchmark_name} chunk {chunk} fully done!")

aepsych/benchmark/problem.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,9 @@ def evaluate(
155155
# always eval f
156156
f_hat = self.f_hat(model)
157157
p_hat = self.p_hat(model)
158-
assert (
159-
self.f_true.shape == f_hat.shape
160-
), f"self.f_true.shape=={self.f_true.shape} != f_hat.shape=={f_hat.shape}"
158+
assert self.f_true.shape == f_hat.shape, (
159+
f"self.f_true.shape=={self.f_true.shape} != f_hat.shape=={f_hat.shape}"
160+
)
161161

162162
mae_f = torch.mean(torch.abs(self.f_true - f_hat))
163163
mse_f = torch.mean((self.f_true - f_hat) ** 2)

aepsych/factory/default.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,9 @@ def default_mean_covar_factory(
282282
stacklevel=2,
283283
)
284284

285-
assert (config is not None) or (
286-
dim is not None
287-
), "Either config or dim must be provided!"
285+
assert (config is not None) or (dim is not None), (
286+
"Either config or dim must be provided!"
287+
)
288288

289289
assert stimuli_per_trial in (1, 2), "stimuli_per_trial must be 1 or 2!"
290290

aepsych/factory/pairwise.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,9 @@ def pairwise_mean_covar_factory(
125125
stacklevel=2,
126126
)
127127

128-
assert (
129-
stimuli_per_trial == 1
130-
), f"pairwise_mean_covar_factory must have stimuli_per_trial == 1, but {stimuli_per_trial} was passed instead!"
128+
assert stimuli_per_trial == 1, (
129+
f"pairwise_mean_covar_factory must have stimuli_per_trial == 1, but {stimuli_per_trial} was passed instead!"
130+
)
131131
lb = config.gettensor("common", "lb")
132132
ub = config.gettensor("common", "ub")
133133
assert lb.shape[0] == ub.shape[0], "bounds shape mismatch!"
@@ -162,9 +162,9 @@ def pairwise_mean_covar_factory(
162162

163163
if len(shared_dims) > 0:
164164
active_dims = [i for i in range(config_dim) if i not in shared_dims]
165-
assert (
166-
len(active_dims) % 2 == 0
167-
), "dimensionality of non-shared dims must be even!"
165+
assert len(active_dims) % 2 == 0, (
166+
"dimensionality of non-shared dims must be even!"
167+
)
168168
mean = _get_default_mean_function(config, zero_mean)
169169
cov1 = _get_default_cov_function(
170170
config, len(active_dims) // 2, stimuli_per_trial=1

aepsych/generators/acqf_grid_search_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,5 @@ def _gen(
4747
_, idxs = torch.topk(acqf_vals, num_points)
4848
new_candidate = grid[idxs]
4949

50-
logger.info(f"Gen done, time={time.time()-starttime}")
50+
logger.info(f"Gen done, time={time.time() - starttime}")
5151
return new_candidate

aepsych/generators/acqf_thompson_sampler_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,5 @@ def _gen(
5656
)
5757
new_candidate = grid[candidate_idx]
5858

59-
logger.info(f"Gen done, time={time.time()-starttime}")
59+
logger.info(f"Gen done, time={time.time() - starttime}")
6060
return new_candidate

aepsych/models/semi_p.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,9 @@ def __init__(
308308
)
309309

310310
likelihood = likelihood or LinearBernoulliLikelihood()
311-
assert isinstance(
312-
likelihood, LinearBernoulliLikelihood
313-
), "SemiP model only supports linear Bernoulli likelihoods!"
311+
assert isinstance(likelihood, LinearBernoulliLikelihood), (
312+
"SemiP model only supports linear Bernoulli likelihoods!"
313+
)
314314

315315
super().__init__(
316316
dim=dim,

aepsych/plotting.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -745,9 +745,9 @@ def plot_strat(
745745
DeprecationWarning,
746746
stacklevel=2,
747747
)
748-
assert (
749-
"binary" in strat.outcome_types
750-
), f"Plotting not supported for outcome_type {strat.outcome_types[0]}"
748+
assert "binary" in strat.outcome_types, (
749+
f"Plotting not supported for outcome_type {strat.outcome_types[0]}"
750+
)
751751

752752
if target_level is not None and not hasattr(strat.model, "monotonic_idxs"):
753753
warnings.warn(
@@ -873,7 +873,7 @@ def _plot_strat_1d(
873873
alpha=0.3,
874874
hatch="///",
875875
edgecolor="gray",
876-
label=f"{cred_level*100:.0f}% posterior mass",
876+
label=f"{cred_level * 100:.0f}% posterior mass",
877877
)
878878
if target_level is not None:
879879
from aepsych.utils import interpolate_monotonic
@@ -892,7 +892,7 @@ def _plot_strat_1d(
892892
xerr=np.r_[thresh_med - thresh_lower, thresh_upper - thresh_med][:, None],
893893
capsize=5,
894894
elinewidth=1,
895-
label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass marked)",
895+
label=f"Est. {target_level * 100:.0f}% threshold \n(with {cred_level * 100:.0f}% posterior \nmass marked)",
896896
)
897897

898898
if true_testfun is not None:
@@ -911,7 +911,7 @@ def _plot_strat_1d(
911911
true_thresh,
912912
target_level,
913913
"o",
914-
label=f"True {target_level*100:.0f}% threshold",
914+
label=f"True {target_level * 100:.0f}% threshold",
915915
)
916916

917917
ax.scatter(
@@ -1031,7 +1031,7 @@ def _plot_strat_2d(
10311031
ax.plot(
10321032
context_grid,
10331033
thresh_75.cpu().numpy(),
1034-
label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass shaded)",
1034+
label=f"Est. {target_level * 100:.0f}% threshold \n(with {cred_level * 100:.0f}% posterior \nmass shaded)",
10351035
)
10361036
ax.fill_between(
10371037
context_grid,

aepsych/strategy/sequential.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def _make_next_strat(self) -> None:
7070
return
7171

7272
# populate new model with final data from last model
73-
assert (
74-
self.x is not None and self.y is not None
75-
), "Cannot initialize next strategy; no data has been given!"
73+
assert self.x is not None and self.y is not None, (
74+
"Cannot initialize next strategy; no data has been given!"
75+
)
7676
self.strat_list[self._strat_idx + 1].add_data(self.x, self.y)
7777

7878
self._suggest_count = 0
@@ -146,9 +146,9 @@ def get_config_options(
146146
strat_names = config.getlist("common", "strategy_names", element_type=str)
147147

148148
# ensure strat_names are unique
149-
assert len(strat_names) == len(
150-
set(strat_names)
151-
), f"Strategy names {strat_names} are not all unique!"
149+
assert len(strat_names) == len(set(strat_names)), (
150+
f"Strategy names {strat_names} are not all unique!"
151+
)
152152

153153
strats = []
154154
for name in strat_names:

0 commit comments

Comments
 (0)