Skip to content

Commit 250d317

Browse files
authored
ci: Remove yapf and isort in favor of ruff (#613)
1 parent 409de04 commit 250d317

33 files changed

+527
-267
lines changed

.github/labeler.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,14 @@ ci:
1212
- .github/**/*
1313
- codecov.yaml
1414
- .pre-commit-config.yaml
15-
- .ruff.toml
16-
- .style.yapf
17-
- setup.cfg
1815

1916
documentation:
2017
- any:
2118
- changed-files:
2219
- any-glob-to-any-file:
2320
- docs/**/*
2421
- readthedocs.yml
25-
- README.MD
22+
- README.md
2623

2724
benchmark:
2825
- any:

.github/workflows/aws/upload_final_index.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,24 +69,40 @@
6969

7070
bucket.Object('index.html').upload_file('index.html', ExtraArgs=args)
7171

72-
index_html = html.format('\n'.join([
73-
href.format(f'{torch_version}.html'.replace('+', '%2B'), torch_version)
74-
for torch_version in wheels_dict
75-
]))
72+
index_html = html.format(
73+
'\n'.join(
74+
[
75+
href.format(
76+
f'{torch_version}.html'.replace('+', '%2B'),
77+
torch_version,
78+
)
79+
for torch_version in wheels_dict
80+
],
81+
),
82+
)
7683

7784
with open('index.html', 'w') as f:
7885
f.write(index_html)
7986

8087
bucket.Object('whl/index.html').upload_file('index.html', ExtraArgs=args)
8188

8289
for torch_version, wheels in wheels_dict.items():
83-
torch_version_html = html.format('\n'.join([
84-
href.format(f'{orig_torch_version}/{wheel}'.replace('+', '%2B'), wheel)
85-
for orig_torch_version, wheel in wheels
86-
]))
90+
torch_version_html = html.format(
91+
'\n'.join(
92+
[
93+
href.format(
94+
f'{orig_torch_version}/{wheel}'.replace('+', '%2B'),
95+
wheel,
96+
)
97+
for orig_torch_version, wheel in wheels
98+
],
99+
),
100+
)
87101

88102
with open(f'{torch_version}.html', 'w') as f:
89103
f.write(torch_version_html)
90104

91105
bucket.Object(f'whl/{torch_version}.html').upload_file(
92-
f'{torch_version}.html', ExtraArgs=args)
106+
f'{torch_version}.html',
107+
ExtraArgs=args,
108+
)

.github/workflows/aws/upload_nightly_index.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,35 @@
4343
if '2.9.0' in torch_version:
4444
wheels_dict[torch_version.replace('2.9.0', '2.9.1')].append(wheel)
4545

46-
index_html = html.format('\n'.join([
47-
href.format(f'{version}.html'.replace('+', '%2B'), version)
48-
for version in wheels_dict
49-
]))
46+
index_html = html.format(
47+
'\n'.join(
48+
[
49+
href.format(f'{version}.html'.replace('+', '%2B'), version)
50+
for version in wheels_dict
51+
],
52+
),
53+
)
5054

5155
with open('index.html', 'w') as f:
5256
f.write(index_html)
5357
bucket.Object('whl/nightly/index.html').upload_file('index.html', args)
5458

5559
for torch_version, wheel_names in wheels_dict.items():
56-
torch_version_html = html.format('\n'.join([
57-
href.format(f'{wheel_name}'.replace('+', '%2B'),
58-
wheel_name.split('/')[-1]) for wheel_name in wheel_names
59-
]))
60+
torch_version_html = html.format(
61+
'\n'.join(
62+
[
63+
href.format(
64+
f'{wheel_name}'.replace('+', '%2B'),
65+
wheel_name.split('/')[-1],
66+
)
67+
for wheel_name in wheel_names
68+
],
69+
),
70+
)
6071

6172
with open(f'{torch_version}.html', 'w') as f:
6273
f.write(torch_version_html)
6374
bucket.Object(f'whl/nightly/{torch_version}.html').upload_file(
64-
f'{torch_version}.html', args)
75+
f'{torch_version}.html',
76+
args,
77+
)

.pre-commit-config.yaml

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,30 +48,21 @@ repos:
4848
--in-place,
4949
]
5050

51-
- repo: https://github.com/google/yapf
52-
rev: v0.43.0
53-
hooks:
54-
- id: yapf
55-
name: Format code
56-
57-
- repo: https://github.com/pycqa/isort
58-
rev: 8.0.1
59-
hooks:
60-
- id: isort
61-
name: Sort imports
62-
6351
- repo: https://github.com/astral-sh/ruff-pre-commit
6452
rev: v0.15.7
6553
hooks:
66-
- id: ruff
67-
name: Ruff formatting
54+
- id: ruff-check
55+
name: ruff check
6856
args: [--fix, --exit-non-zero-on-fix]
57+
- id: ruff-format
58+
name: ruff format
6959

7060
- repo: https://github.com/PyCQA/flake8
7161
rev: 7.3.0
7262
hooks:
7363
- id: flake8
7464
name: Check PEP8
65+
args: [--extend-ignore=E203]
7566

7667
- repo: https://github.com/pre-commit/mirrors-clang-format
7768
rev: v22.1.1

benchmark/classes/hash_map.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99
if __name__ == '__main__':
1010
parser = argparse.ArgumentParser()
1111
parser.add_argument('--device', type=str, default='cuda')
12-
parser.add_argument('--dtype', type=str, default='int',
13-
choices=['short', 'int', 'long'])
12+
parser.add_argument(
13+
'--dtype',
14+
type=str,
15+
default='int',
16+
choices=['short', 'int', 'long'],
17+
)
1418
parser.add_argument('--num_keys', type=int, default=10_000_000)
1519
parser.add_argument('--num_queries', type=int, default=1_000_000)
1620
args = parser.parse_args()
@@ -25,14 +29,19 @@
2529

2630
max_value = torch.iinfo(dtype).max
2731

28-
key1 = torch.randint(0, max_value, (args.num_keys, ), dtype=dtype,
29-
device=args.device).unique()
32+
key1 = torch.randint(
33+
0,
34+
max_value,
35+
(args.num_keys,),
36+
dtype=dtype,
37+
device=args.device,
38+
).unique()
3039
query1 = key1[torch.randperm(key1.size(0), device=args.device)]
31-
query1 = query1[:args.num_queries]
40+
query1 = query1[: args.num_queries]
3241

3342
key2 = torch.randperm(args.num_keys, dtype=dtype, device=args.device)
3443
query2 = torch.randperm(args.num_queries, dtype=dtype, device=args.device)
35-
query2 = query2[:args.num_queries]
44+
query2 = query2[: args.num_queries]
3645

3746
if key1.is_cuda:
3847
t_init = t_get = 0
@@ -77,10 +86,17 @@
7786
for i in range(num_warmups + num_steps):
7887
torch.cuda.synchronize()
7988
t_start = time.perf_counter()
80-
hash_map = torch.full((args.num_keys, ), fill_value=-1, dtype=dtype,
81-
device=args.device)
82-
hash_map[key2.long()] = torch.arange(args.num_keys, dtype=dtype,
83-
device=args.device)
89+
hash_map = torch.full(
90+
(args.num_keys,),
91+
fill_value=-1,
92+
dtype=dtype,
93+
device=args.device,
94+
)
95+
hash_map[key2.long()] = torch.arange(
96+
args.num_keys,
97+
dtype=dtype,
98+
device=args.device,
99+
)
84100
torch.cuda.synchronize()
85101
if i >= num_warmups:
86102
t_init += time.perf_counter() - t_start
@@ -99,8 +115,10 @@
99115
t_init = t_get = 0
100116
for i in range(num_warmups + num_steps):
101117
t_start = time.perf_counter()
102-
hash_map = pd.CategoricalDtype(categories=key1.numpy(),
103-
ordered=True)
118+
hash_map = pd.CategoricalDtype(
119+
categories=key1.numpy(),
120+
ordered=True,
121+
)
104122
if i >= num_warmups:
105123
t_init += time.perf_counter() - t_start
106124

benchmark/ops/sampled.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
if args.device == 'cpu':
1818
num_warmups, num_steps = num_warmups // 10, num_steps // 10
1919

20-
a_index = torch.randint(0, num_nodes, (num_edges, ), device=args.device)
21-
b_index = torch.randint(0, num_nodes, (num_edges, ), device=args.device)
20+
a_index = torch.randint(0, num_nodes, (num_edges,), device=args.device)
21+
b_index = torch.randint(0, num_nodes, (num_edges,), device=args.device)
2222
out_grad = torch.randn(num_edges, num_feats, device=args.device)
2323

2424
for fn in ['add', 'sub', 'mul', 'div']:

benchmark/ops/spline.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
try:
99
import torch_spline_conv
10+
1011
HAS_TORCH_SPLINE_CONV = True
1112
except ImportError:
1213
HAS_TORCH_SPLINE_CONV = False
@@ -33,10 +34,16 @@
3334
for degree in [1, 2, 3]:
3435
E, D = 10000, 3
3536
pseudo = torch.rand(E, D, dtype=torch.float, device=args.device)
36-
kernel_size = torch.tensor([5] * D, dtype=torch.long,
37-
device=args.device)
38-
is_open_spline = torch.tensor([1] * D, dtype=torch.uint8,
39-
device=args.device)
37+
kernel_size = torch.tensor(
38+
[5] * D,
39+
dtype=torch.long,
40+
device=args.device,
41+
)
42+
is_open_spline = torch.tensor(
43+
[1] * D,
44+
dtype=torch.uint8,
45+
device=args.device,
46+
)
4047
label = f'spline_basis (degree={degree}, E={E}, D={D})'
4148

4249
if bench_original:
@@ -114,8 +121,13 @@
114121
x = torch.randn(E, M_in, dtype=torch.float, device=args.device)
115122
weight = torch.randn(K, M_in, M_out, dtype=torch.float, device=args.device)
116123
basis = torch.rand(E, S, dtype=torch.float, device=args.device)
117-
weight_index = torch.randint(0, K, (E, S), dtype=torch.long,
118-
device=args.device)
124+
weight_index = torch.randint(
125+
0,
126+
K,
127+
(E, S),
128+
dtype=torch.long,
129+
device=args.device,
130+
)
119131
label = f'spline_weighting (E={E}, M_in={M_in}, M_out={M_out}, K={K})'
120132

121133
if bench_original:

benchmark/sampler/hetero_neighbor.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,48 @@
1414
from pyg_lib.testing import remap_keys, withDataset, withSeed
1515

1616
argparser = argparse.ArgumentParser('Hetero neighbor sample benchmark')
17-
argparser.add_argument('--batch-sizes', nargs='+', type=int, default=[
18-
512,
19-
1024,
20-
2048,
21-
4096,
22-
8192,
23-
])
17+
argparser.add_argument(
18+
'--batch-sizes',
19+
nargs='+',
20+
type=int,
21+
default=[
22+
512,
23+
1024,
24+
2048,
25+
4096,
26+
8192,
27+
],
28+
)
2429

2530
# TODO (kgajdamo): Support undirected hetero graphs
2631
# argparser.add_argument('--directed', action='store_true')
2732
argparser.add_argument('--disjoint', action='store_true')
28-
argparser.add_argument('--num_neighbors', type=ast.literal_eval, default=[
29-
[-1],
30-
[15, 10, 5],
31-
[20, 15, 10],
32-
])
33+
argparser.add_argument(
34+
'--num_neighbors',
35+
type=ast.literal_eval,
36+
default=[
37+
[-1],
38+
[15, 10, 5],
39+
[20, 15, 10],
40+
],
41+
)
3342
# TODO(kgajdamo): Enable sampling with replacement
3443
# argparser.add_argument('--replace', action='store_true')
3544
argparser.add_argument('--shuffle', action='store_true')
3645
argparser.add_argument('--biased', action='store_true')
3746
argparser.add_argument('--temporal', action='store_true')
38-
argparser.add_argument('--temporal-strategy', choices=['uniform', 'last'],
39-
default='uniform')
47+
argparser.add_argument(
48+
'--temporal-strategy',
49+
choices=['uniform', 'last'],
50+
default='uniform',
51+
)
4052
argparser.add_argument('--write-csv', action='store_true')
41-
argparser.add_argument('--libraries', nargs="*", type=str,
42-
default=['pyg-lib', 'torch-sparse'])
53+
argparser.add_argument(
54+
'--libraries',
55+
nargs='*',
56+
type=str,
57+
default=['pyg-lib', 'torch-sparse'],
58+
)
4359
args = argparser.parse_args()
4460

4561

@@ -48,7 +64,8 @@
4864
def test_hetero_neighbor(dataset, **kwargs):
4965
if args.temporal and not args.disjoint:
5066
raise ValueError(
51-
"Temporal sampling needs to create disjoint subgraphs")
67+
'Temporal sampling needs to create disjoint subgraphs',
68+
)
5269

5370
colptr_dict, row_dict = dataset
5471
num_nodes_dict = {k[-1]: v.size(0) - 1 for k, v in colptr_dict.items()}
@@ -57,7 +74,8 @@ def test_hetero_neighbor(dataset, **kwargs):
5774
if args.temporal:
5875
# generate random timestamps
5976
node_time, _ = torch.sort(
60-
torch.randint(0, 100000, (num_nodes_dict['paper'], )))
77+
torch.randint(0, 100000, (num_nodes_dict['paper'],)),
78+
)
6179
node_time_dict = {'paper': node_time}
6280
else:
6381
node_time_dict = None
@@ -75,9 +93,10 @@ def test_hetero_neighbor(dataset, **kwargs):
7593
node_perm = torch.arange(0, num_nodes_dict['paper'])
7694

7795
data = defaultdict(list)
78-
for num_neighbors, batch_size in product(args.num_neighbors,
79-
args.batch_sizes):
80-
96+
for num_neighbors, batch_size in product(
97+
args.num_neighbors,
98+
args.batch_sizes,
99+
):
81100
print(f'batch_size={batch_size}, num_neighbors={num_neighbors}):')
82101
data['num_neighbors'].append(num_neighbors)
83102
data['batch-size'].append(batch_size)

0 commit comments

Comments
 (0)