Commit e0439d2

Fixed python code formatting and added flake8 setup (#346)
* Fixed python code formatting and added flake8 setup
* Fixes config.yaml
* Added missing setup.cfg
* Removed flake8 job from circle ci and setup GHA
* More flake8 fixes
* Fixed test_conv2d
* Fixed failing flake8
1 parent 0b15a12 commit e0439d2

Some content is hidden: large commits have some of their changed files collapsed by default, so only a subset of the 45 changed files appears below.

45 files changed: +837, -475 lines
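Note: the commit message mentions adding a missing setup.cfg to hold the flake8 configuration, but that file is among the hidden diffs. Purely as an illustration of what such a configuration usually looks like (these values are assumptions, not the contents of the actual file), a [flake8] section in setup.cfg might read:

[flake8]
# Hypothetical example values; the real setup.cfg added by this commit is not shown here.
max-line-length = 120
ignore = E203, W503
exclude = .git,__pycache__,build

flake8 picks this section up automatically when invoked from the repository root, which is what the new lint workflow does when it runs "flake8 .".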

.github/workflows/lint.yml

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+name: Lint
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+
+jobs:
+  flake8-py3:
+    runs-on: ubuntu-18.04
+    steps:
+      - uses: actions/checkout@v2
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.x
+          architecture: x64
+      - name: Install dependencies
+        run: |
+          set -eux
+          pip3 install flake8 --user
+          flake8 --version
+      - name: Run flake8
+        run: |
+          set -eux
+          flake8 .
+
+concurrency:
+  group: lint-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
+  cancel-in-progress: true

codegen/codegen_outofplacebatching.py

Lines changed: 23 additions & 2 deletions
@@ -6,48 +6,55 @@

 import argparse
 from typing import Tuple, List
-from collections import defaultdict
 import re

+
 def num_leading_spaces(line: str) -> int:
     return len(line) - len(line.lstrip())

+
 def min_leading_spaces(lines):
     num_spaces = [num_leading_spaces(line) for line in lines if len(line) > 0]
     if len(num_spaces) == 0:
         return None
     return min(num_spaces)

+
 def deindent(code: str) -> str:
     lines = code.split('\n')
     mls = min_leading_spaces(lines)
     lines = [line[mls:] for line in lines]
     return '\n'.join(lines)

+
 def indent(code: str, num) -> str:
     lines = code.split('\n')
     indented_lines = [' ' * num + line for line in lines]
     indented_lines[0] = lines[0]
     return '\n'.join(indented_lines)

+
 def is_tensor(typ: str) -> bool:
     if typ == 'Tensor':
         return True
     if typ == 'const Tensor &':
         return True
     return False

+
 def is_optional_tensor(typ: str) -> bool:
     if typ == 'c10::optional<Tensor>':
         return True
     if typ == 'const c10::optional<Tensor> &':
         return True
     return False

+
 def is_vector_tensor(typ: str) -> bool:
     # (chilli): I don't really understand why there's 2 dots in front?
     return (typ == '::std::vector<Tensor>')

+
 def add_bdim_after_tensor(types: Tuple[str]) -> Tuple[str]:
     result = []
     for typ in types:
@@ -56,6 +63,7 @@ def add_bdim_after_tensor(types: Tuple[str]) -> Tuple[str]:
             result.append('c10::optional<int64_t>')
     return tuple(result)

+
 def batch_rule_type(
         op_returns: Tuple[str],
         op_args: Tuple[str],
@@ -67,13 +75,15 @@ def batch_rule_type(
     result = f"typedef std::tuple<{','.join(returns)}> (*{br_t})({', '.join(args)});"
     return result, br_t

+
 def unwrap_tensor(name: str) -> List[str]:
     result = f"""\
     Tensor {name}_value;
     optional<int64_t> {name}_bdim;
     std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}, cur_level);"""
     return deindent(result).split('\n')

+
 def unwrap_optional_tensor(name: str) -> List[str]:
     result = f"""\
     optional<Tensor> {name}_value;
@@ -83,6 +93,7 @@ def unwrap_optional_tensor(name: str) -> List[str]:
     }}"""
     return deindent(result).split('\n')

+
 def gen_unwraps(arg_types, arg_names):
     tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)]
     optional_tensors = [name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ)]
@@ -103,6 +114,7 @@ def gen_unwraps(arg_types, arg_names):
         unwrapped_arg_list.append(arg)
     return unwraps, unwrapped_arg_list

+
 def lower(returns: Tuple[str], args: List[Tuple[str, str]], unique_count: int, ops) -> str:
     arg_types, arg_names = zip(*args)
     batch_rule_typedef, batch_rule_t = batch_rule_type(returns, arg_types, unique_count)
@@ -120,7 +132,9 @@ def lower(returns: Tuple[str], args: List[Tuple[str, str]], unique_count: int, o
             wrapped_returns.append(f'makeBatched(std::get<{idx}>(results), std::get<{idx + 1}>(results), cur_level)')
             idx += 2
         elif is_vector_tensor(ret):
-            wrapped_returns.append(f'makeBatchedVector(std::get<{idx}>(results), std::get<{idx + 1}>(results), cur_level)')
+            wrapped_returns.append(
+                f'makeBatchedVector(std::get<{idx}>(results), std::get<{idx + 1}>(results), cur_level)'
+            )
             idx += 2
         else:
             wrapped_returns.append(f'std::get<{idx}>(results)')
@@ -148,6 +162,7 @@ def lower(returns: Tuple[str], args: List[Tuple[str, str]], unique_count: int, o
     }}"""
     return deindent(result)

+
 def parse_return(return_t):
     if 'std::tuple' not in return_t:
         return (return_t,)
@@ -156,6 +171,7 @@ def parse_return(return_t):
     m = re.match(r'::std::tuple<(.*)>', return_t)
     return tuple([x.strip() for x in m.group(1).split(',')])

+
 def parse_args(args_t):
     # There is an assumption made that args are separated with comma-space
     # and types like std::array<bool,2> do not contain spaces after the comma
@@ -166,6 +182,7 @@ def parse_args(args_t):
         result.append((arg[:split_idx].strip(), arg[split_idx:].strip()))
     return tuple(result)

+
 def get_signatures(path='build/aten/src/ATen/RegistrationDeclarations.h', include_op=False):
     with open(path, 'r') as f:
         txt = f.read()
@@ -188,6 +205,7 @@ def get_signatures(path='build/aten/src/ATen/RegistrationDeclarations.h', includ
         schemas.append(result)
     return tuple(schemas)

+
 def is_schema_outplace(schema):
     _, returns, args = schema
     for arg in args:
@@ -207,16 +225,19 @@ def is_schema_outplace(schema):
             return False
     return True

+
 def get_hash(schema):
     ret_t, args = schema
     args_t, _ = tuple(zip(*args))
     return (ret_t, args_t)

+
 class Container:
     def __init__(self, schema, ops):
         self.schema = schema
         self.ops = ops

+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('path',
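To make the whitespace helpers above easier to follow, here is a small self-contained sketch that exercises deindent and indent. The function bodies are copied from the diff; the demo string and comments are illustrative additions, not part of the commit.

# Standalone sketch of the indentation helpers shown in the diff above.
# The demo string below is illustrative; it is not taken from the codegen templates.

def num_leading_spaces(line: str) -> int:
    return len(line) - len(line.lstrip())


def min_leading_spaces(lines):
    num_spaces = [num_leading_spaces(line) for line in lines if len(line) > 0]
    if len(num_spaces) == 0:
        return None
    return min(num_spaces)


def deindent(code: str) -> str:
    # Strip the common leading indentation from every line.
    lines = code.split('\n')
    mls = min_leading_spaces(lines)
    lines = [line[mls:] for line in lines]
    return '\n'.join(lines)


def indent(code: str, num) -> str:
    # Re-indent every line except the first by `num` spaces.
    lines = code.split('\n')
    indented_lines = [' ' * num + line for line in lines]
    indented_lines[0] = lines[0]
    return '\n'.join(indented_lines)


snippet = "    Tensor self_value;\n    optional<int64_t> self_bdim;"
flat = deindent(snippet)      # both lines lose their common 4-space prefix
print(flat)
print(indent(flat, 2))        # first line unchanged, second line gains 2 spaces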

codegen/gen_plumbing.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 import argparse
-import re

 from codegen_outofplacebatching import deindent, get_signatures, gen_unwraps

examples/compilation/eager_fusion.py

Lines changed: 11 additions & 4 deletions
@@ -1,15 +1,17 @@
 from functorch import compiled_function, tvm_compile
-import torch.nn as nn
 import torch
 from functools import partial
 import time
 import torch.utils

 a = torch.randn(2000, 1, 4, requires_grad=True)
 b = torch.randn(1, 2000, 4)
+
+
 def f(a):
     return (a * b).sum(dim=0)

+
 fw_compiler = partial(tvm_compile, name='fw_keops')
 bw_compiler = partial(tvm_compile, name='bw_keops')
 compiled_f = compiled_function(f, fw_compiler, bw_compiler)
@@ -19,30 +21,35 @@ def f(a):
 iters = 10
 out = compiled_f(a)
 out.sum().backward()
+
+
 def bench(func):
     begin = time.time()
     for _ in range(iters):
         out = func(a).sin()
         out.sum().backward()
         a.grad = None
-    print(time.time()-begin)
+    print(time.time() - begin)
+

 def bench_jax():
     import jax.numpy as jnp
     import jax
     jax_a = jnp.array(a.detach().numpy())
     jax_b = jnp.array(b.detach().numpy())
+
     def f(a):
-        return jnp.sin((a*jax_b).sum(axis=[0])).sum()
+        return jnp.sin((a * jax_b).sum(axis=[0])).sum()
     jit_f = jax.jit(jax.grad(f))
     jit_f(jax_a)
     begin = time.time()
     for _ in range(iters):
         out = jit_f(jax_a)
         out.block_until_ready()
-    print(time.time()-begin)
+    print(time.time() - begin)
     # for

+
 bench(f)
 bench(compiled_f)
 # bench_jax()

examples/compilation/fuse_module.py

Lines changed: 6 additions & 1 deletion
@@ -1,22 +1,27 @@
+import timeit
 from functorch import compiled_module, tvm_compile
 import torch.nn as nn
 import torch
 from functools import partial

+
 def nop(f, _):
     return f

+
 fw_compiler = partial(tvm_compile, name='fw_keops')
 bw_compiler = partial(tvm_compile, name='bw_keops')
 fw_compiler = nop
 bw_compiler = nop

+
 def run(mod, input):
     out = mod(input)
     out.sum().backward()
     grads = [p.grad for p in mod.parameters()]
     return (out, *grads)

+
 class Foo(nn.Module):
     def __init__(self):
         super(Foo, self).__init__()
@@ -26,6 +31,7 @@ def __init__(self):
     def forward(self, x):
         return (self.param * x + self.buf).sum(dim=0)

+
 input = torch.randn(1)
 mod = Foo()
 compiled_mod = compiled_module(mod, fw_compiler, bw_compiler)
@@ -42,7 +48,6 @@ def forward(self, x):
 for a, b in zip(run(mod, input), run(compiled_mod, input)):
     torch.testing.assert_allclose(a, b)

-import timeit
 for _ in range(5):
     i = 10000
     t = timeit.Timer("mod(input)", globals=globals()).timeit(10000)

examples/compilation/linear_train.py

Lines changed: 13 additions & 7 deletions
@@ -4,20 +4,20 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from functorch import grad, vmap, pythonkey_trace, wrap_key, make_fx, nnc_jit, make_functional, grad_and_value
+from functorch import nnc_jit, make_functional
 import torch
-import torch.fx as fx
 import torch.nn as nn
 import time
 torch._C._jit_override_can_fuse_on_cpu(True)

+
 def bench(f, iters=100, warmup=10):
     for _ in range(warmup):
         f()
     begin = time.time()
     for _ in range(iters):
         f()
-    print((time.time()-begin))
+    print((time.time() - begin))


 class Foo(nn.Module):
@@ -42,23 +42,28 @@ def forward(self, x):
 jit_mod = torch.jit.script(mod)

 func_model, weights = make_functional(mod)
-lr =1.0
+lr = 1.0
+

 def functional_step(x, weights):
     weights = [weight.detach().requires_grad_() for weight in weights]
     out = func_model(weights, x)
     out.backward()
-    new_weights = [weight - lr*weight.grad for weight in weights]
+    new_weights = [weight - lr * weight.grad for weight in weights]
     return out, new_weights

-optim = torch.optim.SGD(jit_mod.parameters(), lr=lr, momentum=0,dampening=0, weight_decay=0)
+
+optim = torch.optim.SGD(jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0)
+
+
 def jit_step(x, weights):
     optim.zero_grad()
     loss = jit_mod(x)
     loss.backward()
     optim.step()
     return loss, None

+
 def train(train_step, weights):
     torch.manual_seed(16)
     train_step(inp, weights)
@@ -67,9 +72,10 @@ def train(train_step, weights):
         loss, weights = train_step(torch.randn(batch_size, features), weights)
         if itr % 200 == 0:
             print(f"Loss at {itr}: {loss}")
-    print("Time taken: ", time.time()-begin)
+    print("Time taken: ", time.time() - begin)
     print()

+
 grad_pt = functional_step
 grad_nnc = nnc_jit(functional_step)

examples/compilation/simple_function.py

Lines changed: 4 additions & 0 deletions
@@ -8,16 +8,19 @@
 import torch
 import time

+
 def f(x):
     return torch.sin(x).sum()

+
 inp = torch.randn(100)
 grad_pt = grad(f)
 grad_fx = make_fx(grad_pt)(inp)
 grad_nnc = nnc_jit(grad_pt, skip_specialization=True)
 loopnest = make_nnc(grad_pt)(inp)
 print(loopnest)

+
 def bench(name, f, iters=10000, warmup=3):
     for _ in range(warmup):
         f()
@@ -26,6 +29,7 @@ def bench(name, f, iters=10000, warmup=3):
         f()
     print(f"{name}: ", time.time() - begin)

+
 bench("Pytorch: ", lambda: grad_pt(inp))
 bench("FX: ", lambda: grad_fx(inp))
 bench("NNC: ", lambda: grad_nnc(inp))
