
Commit 308f50b

Align functorch's flake8 config with pytorch's (#963)
1 parent d40a496 commit 308f50b

File tree

11 files changed: 40 additions, 35 deletions


.flake8

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+[flake8]
+select = B,C,E,F,P,T4,W,B9
+max-line-length = 120
+# C408 ignored because we like the dict keyword argument syntax
+# E501 is not flexible enough, we're using B950 instead
+ignore =
+    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
+    # shebang has extra meaning in fbcode lints, so I think it's not worth trying
+    # to line this up with executable bit
+    EXE001,
+    # these ignores are from flake8-bugbear; please fix!
+    B007,B008,
+    # these ignores are from flake8-comprehensions; please fix!
+    C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
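For context on the C408 note above: flake8 reads this [flake8] section automatically when run from the repository root, and flake8-comprehensions' C408 flags dict() called with keyword arguments, a style this config deliberately keeps allowed. A minimal illustrative sketch (not part of the commit; the variable names are made up):

    # C408 would normally flag the dict(...) keyword form; this config ignores it on purpose.
    opts_call = dict(select="B,C,E,F,P,T4,W,B9", max_line_length=120)       # dict keyword syntax
    opts_literal = {"select": "B,C,E,F,P,T4,W,B9", "max_line_length": 120}  # equivalent literal
    assert opts_call == opts_literal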

codegen/gen.py

Lines changed: 0 additions & 3 deletions
@@ -1,8 +1,5 @@
 import os
 import argparse
-import pathlib
-
-import torchgen
 from torchgen.gen import FileManager, parse_native_yaml
 from torchgen.gen import get_torchgen_root
 from gen_vmap_plumbing import gen_all_vmap_plumbing

examples/compilation/fuse_module.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 from functorch.compile import compiled_module, tvm_compile
 import torch.nn as nn
 import torch
-from functools import partial
 
 
 def nop(f, _):

functorch/_src/partitioners.py

Lines changed: 6 additions & 6 deletions
@@ -269,7 +269,7 @@ def classify_nodes(joint_module):
 
     aten = torch.ops.aten
 
-    pointwise_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward]  # noqa: E501
+    pointwise_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward]  # noqa: E501
     misc_ops = [aten.to, aten.type_as, operator.getitem]
 
     reduction_ops = [aten.softmax, aten._softmax, aten._softmax_backward_data, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax]  # noqa: E501
@@ -338,27 +338,27 @@ def get_node_weight(node):
             continue
 
         if node in required_bw_nodes:
-            nx_graph.add_edge(node.name+"_in", "sink", capacity=math.inf)
+            nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
             continue
 
         if node.op == 'placeholder' and "primals" in node.target:
-            nx_graph.add_edge("source", node.name+"_in", capacity=math.inf)
+            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)
 
         # If a node can't be recomputed (too expensive or involves randomness),
         # we prevent it from being recomputed by adding an inf edge to the source
         # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
         if ban_recomputation(node) and node in required_fw_nodes:
-            nx_graph.add_edge("source", node.name+"_in", capacity=math.inf)
+            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)
 
         if 'tensor_meta' not in node.meta:
             weight = math.inf
         else:
             weight = get_node_weight(node)
 
         # Creates the weights on the "node" edge
-        nx_graph.add_edge(node.name+"_in", node.name+"_out", capacity=weight)
+        nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight)
         for user in node.users:
-            nx_graph.add_edge(node.name+"_out", user.name+"_in", capacity=math.inf)
+            nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf)
 
     cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
     reachable, non_reachable = partition
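For readers unfamiliar with the surrounding code: this partitioner frames "which activations to save versus recompute" as a min-cut problem. Each FX node is split into an "_in"/"_out" pair joined by an edge whose capacity is the node's weight, so cutting that edge corresponds to saving the node, while infinite-capacity edges pin down nodes that must not be recomputed. A self-contained toy sketch of that construction, with made-up node names and weights (not the actual partitioner code):

    import math
    import networkx as nx

    # Toy graph: source -> a -> b -> c -> sink, where each node is split into
    # name_in/name_out and the connecting edge carries that node's "recompute cost".
    nx_graph = nx.DiGraph()
    node_weights = {"a": 1.0, "b": math.inf, "c": 2.0}  # inf = banned from recomputation
    for name, weight in node_weights.items():
        nx_graph.add_edge(name + "_in", name + "_out", capacity=weight)

    # Infinite-capacity edges can never be cut, so they force structure on the partition.
    nx_graph.add_edge("source", "a_in", capacity=math.inf)  # e.g. a forward input
    nx_graph.add_edge("a_out", "b_in", capacity=math.inf)
    nx_graph.add_edge("b_out", "c_in", capacity=math.inf)
    nx_graph.add_edge("c_out", "sink", capacity=math.inf)   # e.g. needed by backward

    # The cheapest set of "_in"->"_out" edges to cut tells us which nodes to save.
    cut_value, (reachable, non_reachable) = nx.minimum_cut(nx_graph, "source", "sink")
    print(cut_value, sorted(reachable))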

test/common_utils.py

Lines changed: 2 additions & 1 deletion
@@ -45,6 +45,7 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
 # - Don't hash tensors in a global context, that'll keep them around forever
 def memoize(fn):
     memo = {}
+
     def wrapped(*args):
         if args not in memo:
             memo[args] = fn(*args)
@@ -78,7 +79,7 @@ def get_bdim_choices_batch_norm(num_tensors, _, running_mean=None, running_var=N
     options = (-1, None)
 
     # instance norm turns these into unbatched 0 tensors, so we cannot batch the input if either is not specified
-    if running_mean == None or running_var == None:
+    if running_mean is None or running_var is None:
         choices.append((None,) + (0,) * (num_tensors - 1))
     for choice in itertools.product(options, repeat=num_tensors - 1):
         choices.append((None,) + choice)
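The == None to is None change is more than style (flake8 reports it as E711): is None is an identity check, while == dispatches to the left operand's __eq__, which for tensor-like objects can return something other than a plain bool. A small hypothetical illustration (the FakeTensor class below is invented for this sketch):

    class FakeTensor:
        # Hypothetical stand-in for an object with an overloaded, elementwise-style __eq__.
        def __eq__(self, other):
            return [False, False]  # not a plain bool, similar to tensor == tensor

    t = FakeTensor()
    print(t == None)   # [False, False]; a non-empty list is truthy, so "if t == None:" would fire
    print(t is None)   # False; identity check behaves as intended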

test/test_eager_transforms.py

Lines changed: 6 additions & 6 deletions
@@ -365,7 +365,7 @@ def f(x):
         vjp_fn(result)
 
     def test_conj_bit(self):
-        x = torch.tensor(1+1j)
+        x = torch.tensor(1 + 1j)
 
         def foo(x):
             assert not x.is_conj()
@@ -1647,8 +1647,8 @@ def test_jacfwd_different_levels(self, device):
         A = 0.1 * torch.randn(b, d, d, device=device)
 
         def loss(A, x1, x2):
-            x2_hat = (A@(x1.T)).T
-            res = x2-x2_hat
+            x2_hat = (A @ (x1.T)).T
+            res = x2 - x2_hat
             res_sqr = res**2
             return res_sqr.sum()
 
@@ -2528,7 +2528,7 @@ def inner_loss(params, x1, y1):
             else:
                 loss = inner_loss(params, x1, y1)
                 grads = torch.autograd.grad(loss, params, create_graph=True)
-            new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]
+            new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]
 
             v_f = net(new_params, x2)
             return mse_loss(v_f, y2)
@@ -2537,15 +2537,15 @@ def inner_loss(params, x1, y1):
 
         # Compute with vmap+grad
         inner_losses = vmap(partial(get_loss_for_task, True))(task[0], task[1], task[2], task[3])
-        loss2 = sum(inner_losses)/len(inner_losses)
+        loss2 = sum(inner_losses) / len(inner_losses)
         result_grads = torch.autograd.grad(loss2, params)
 
         # Compute without vmap+grad
         inner_losses = [
             get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i])
             for i in range(num_tasks)
         ]
-        loss2 = sum(inner_losses)/len(inner_losses)
+        loss2 = sum(inner_losses) / len(inner_losses)
         expected_grads = torch.autograd.grad(loss2, params)
 
         self.assertEqual(result_grads, expected_grads)
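The test above relies on vmap over a batch of tasks agreeing with an explicit Python loop over the same data. A tiny self-contained sketch of that equivalence, using a toy function and shapes rather than the test's actual networks:

    import torch
    from functorch import vmap

    # Per-example function; vmapping it over dim 0 should match a Python loop.
    def per_example_loss(x, y):
        return ((x - y) ** 2).mean()

    xs, ys = torch.randn(5, 3), torch.randn(5, 3)
    batched = vmap(per_example_loss)(xs, ys)
    looped = torch.stack([per_example_loss(xs[i], ys[i]) for i in range(5)])
    torch.testing.assert_close(batched, looped)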

test/test_functionalize.py

Lines changed: 2 additions & 7 deletions
@@ -1,12 +1,7 @@
-import torch
-
 import functorch
-from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS
-import unittest
 from unittest.mock import patch
 import functools
-
-from functorch.compile import aot_function, nop
+from torch.testing._internal.common_utils import run_tests
 import test_compile_cache
 import test_pythonkey
 
@@ -41,7 +36,7 @@ class FunctionalizeTest(cls):
     return FunctionalizeTest
 
 
-FunctionalizeTestCompileCache = make_functionalize_test(test_compile_cache.TestCompileCache)
+FunctionalizeTestCompileCache = make_functionalize_test(test_compile_cache.TestCompileCache)
 FunctionalizeTestCompileCacheStaticArgs = make_functionalize_test(test_compile_cache.TestCompileCacheStaticArgs)
 FunctionalizeTestPythonKeyAOT = make_functionalize_test(test_pythonkey.TestAOTAutograd)
 FunctionalizeTestPythonKeyContiguous = make_functionalize_test(test_pythonkey.TestContiguous)

test/test_memory_efficient_fusion.py

Lines changed: 2 additions & 2 deletions
@@ -141,8 +141,8 @@ def layer_norm(x, weight, bias):
             mean = torch.mean(x, dim, keepdim=True)
             centered = x - mean
             var = torch.sum(centered * centered, dim, keepdim=True) / x.size(-1)
-            rvar = 1./torch.sqrt(var+eps)
-            normed = (x-mean) * rvar
+            rvar = 1. / torch.sqrt(var + eps)
+            normed = (x - mean) * rvar
             return normed * weight + bias
 
         bs = 10
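The hand-written layer_norm being reformatted here follows the usual formula: subtract the mean, divide by the square root of the biased variance plus eps, then scale and shift. A quick way to sanity-check that formula against PyTorch's built-in, shown as an illustrative sketch rather than code from the repo:

    import torch
    import torch.nn.functional as F

    dim, eps = -1, 1e-5
    x = torch.randn(10, 16)
    weight, bias = torch.randn(16), torch.randn(16)

    # Same computation as the layer_norm in the diff above.
    mean = torch.mean(x, dim, keepdim=True)
    centered = x - mean
    var = torch.sum(centered * centered, dim, keepdim=True) / x.size(-1)
    manual = (x - mean) * (1. / torch.sqrt(var + eps)) * weight + bias

    # Compare against the built-in reference implementation.
    reference = F.layer_norm(x, (16,), weight, bias, eps)
    torch.testing.assert_close(manual, reference)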

test/test_minifier.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 from functorch.compile import minifier
 from functorch import make_fx
 from torch.testing._internal.common_utils import TestCase, run_tests
-import unittest
 
 
 class TestMinifier(TestCase):

test/test_ops.py

Lines changed: 5 additions & 5 deletions
@@ -13,7 +13,7 @@
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_device_type import ops
 from torch.testing._internal.common_device_type import \
-    toleranceOverride, tol
+    toleranceOverride, tol
 from functorch_lagging_op_db import functorch_lagging_op_db
 from functorch_additional_op_db import additional_op_db
 from common_utils import (
@@ -28,7 +28,6 @@
     opsToleranceOverride,
     check_vmap_fallback,
 )
-import unittest
 from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from functorch import grad, vjp, vmap, jacrev, jacfwd
 import torch.autograd.forward_ad as fwAD
@@ -773,8 +772,8 @@ def test_vmapjvpall(self, device, dtype, op):
         xfail('nn.functional.bilinear'),  # trilinear doesn't have batching rule
         xfail('linalg.eigh'),  # _linalg_eigh doesn't have batching rule
         xfail('linalg.eigvalsh'),  # _linalg_eigh doesn't have batching rule
-        xfail('logdet'),          # _linalg_slogdet doesn't have batching rule
-        xfail('linalg.slogdet'),  # _linalg_slogdet doesn't have batching rule
+        xfail('logdet'),  # _linalg_slogdet doesn't have batching rule
+        xfail('linalg.slogdet'),  # _linalg_slogdet doesn't have batching rule
     }))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
     def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
@@ -1133,6 +1132,7 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
             self.assertFalse(op.supports_fwgrad_bwgrad,
                              f"{op.name} now supports forward over reverse without a decomposition. " +
                              "Please remove the decomposition version")
+
         def is_differentiable(t):
             return isinstance(t, torch.Tensor) and t.dtype == torch.float32
         args = (cotangents, *primals)
@@ -1148,7 +1148,7 @@ def is_differentiable(t):
         self.assertEqual(result, expected)
 
     def _make_extremal_inputs(self, shape, device):
-        if shape == None:
+        if shape is None:
             return (None,)
         return (
             torch.full(shape, -1000., device=device),