Skip to content

Commit 6f8b52e

Browse files
cyx-6 and yzh119 authored
Add ruff to pre-commit (#1201)
<!-- .github/pull_request_template.md --> ## 📌 Description Add ruff to pre-commit and reformat code to pass ruff rules. <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Co-authored-by: Yaxing Cai <[email protected]> Co-authored-by: Zihao Ye <[email protected]>
1 parent 10cba70 commit 6f8b52e

File tree

99 files changed

+877
-811
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

99 files changed

+877
-811
lines changed

.pre-commit-config.yaml

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,6 @@ repos:
3535
- id: remove-crlf
3636

3737
# Formatters
38-
- repo: https://github.com/psf/black-pre-commit-mirror
39-
rev: 24.8.0
40-
hooks:
41-
- id: black
42-
exclude: flashinfer/tuning_configs/.*\.py
43-
44-
- repo: https://github.com/pycqa/isort
45-
rev: 5.13.2
46-
hooks:
47-
- id: isort
48-
args: ["--profile=black"] # <-- this one
49-
5038
- repo: https://github.com/pre-commit/mirrors-clang-format
5139
rev: v19.1.1
5240
hooks:
@@ -56,6 +44,16 @@ repos:
5644
(?x)^(3rdparty/.* flashinfer/jit/aot_config.py)$
5745
5846
- repo: https://github.com/pre-commit/mirrors-mypy
59-
rev: '' # Use the sha / tag you want to point at
47+
rev: 'v1.17.1' # Use the sha / tag you want to point at
6048
hooks:
6149
- id: mypy
50+
51+
- repo: https://github.com/astral-sh/ruff-pre-commit
52+
# Ruff version.
53+
rev: v0.12.8
54+
hooks:
55+
# Run the linter.
56+
- id: ruff-check
57+
# Run the formatter.
58+
- id: ruff-format
59+
types_or: [ python, pyi ]

benchmarks/bench_append_paged_kv_cache.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import argparse
22
import dataclasses
3-
from typing import Tuple, cast
3+
from typing import Tuple
44

55
import numpy as np
66
import torch
@@ -139,10 +139,10 @@ def fn() -> None:
139139
print(
140140
f"model: {model_name:8}",
141141
f"seqlens: {seqlens!r:{seqlen_strlen}}",
142-
f"convert: {convert_latency_ms*1e3:2.0f}us",
143-
f"1layer: {latency_ms*1e3:2.0f}us",
144-
f"{model.num_layers}layers: {all_layers_latency_ms*1e3:3.0f}us",
145-
f"throughput: {throughput*1e-9:8.3f}GB/s",
142+
f"convert: {convert_latency_ms * 1e3:2.0f}us",
143+
f"1layer: {latency_ms * 1e3:2.0f}us",
144+
f"{model.num_layers}layers: {all_layers_latency_ms * 1e3:3.0f}us",
145+
f"throughput: {throughput * 1e-9:8.3f}GB/s",
146146
)
147147
print("---")
148148

benchmarks/bench_append_paged_mla_kv_cache.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import argparse
22
import dataclasses
3-
from typing import Tuple, cast
3+
from typing import Tuple
44

55
import numpy as np
66
import torch
@@ -122,10 +122,10 @@ def fn() -> None:
122122
print(
123123
f"model: {model_name:8}",
124124
f"seqlens: {seqlens!r:{seqlen_strlen}}",
125-
f"convert: {convert_latency_ms*1e3:2.0f}us",
126-
f"1layer: {latency_ms*1e3:2.0f}us",
127-
f"{model.num_layers}layers: {all_layers_latency_ms*1e3:3.0f}us",
128-
f"throughput: {throughput*1e-9:8.3f}GB/s",
125+
f"convert: {convert_latency_ms * 1e3:2.0f}us",
126+
f"1layer: {latency_ms * 1e3:2.0f}us",
127+
f"{model.num_layers}layers: {all_layers_latency_ms * 1e3:3.0f}us",
128+
f"throughput: {throughput * 1e-9:8.3f}GB/s",
129129
)
130130
print("---")
131131

benchmarks/bench_batch_attention.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ def main() -> None:
145145
sweep["num_kv_heads"],
146146
sweep["num_qo_heads"],
147147
):
148-
149148
ms_old, ms_new, mem_MB, bw_old, bw_new = run_bench(
150149
kv_lens,
151150
qo_lens,

benchmarks/bench_batch_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def bench_batch_decode(
7575
f"batch_size={batch_size}, seq_len={seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_block_size={page_block_size}, q_dtype={q_dtype}, kv_dtype={kv_dtype}"
7676
)
7777
print(f"execution time: {ms}ms")
78-
print(f"memory bandwidth: {io / ms / 1024 / 1024 :.2f} GB/s")
78+
print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s")
7979

8080

8181
if __name__ == "__main__":

benchmarks/bench_blackwell_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def bench_fmha_blackwell(
6161
q_data_type=dtype,
6262
kv_data_type=dtype,
6363
)
64-
o = wrapper.run(q, k, v)
64+
_o = wrapper.run(q, k, v)
6565
measurements = bench_gpu_time(
6666
lambda: wrapper.run(q, k, v),
6767
dry_run_time_ms=100,

benchmarks/bench_cutlass_fused_moe.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def bench_cutlass_fused_moe(
8484
n = intermediate_size
8585
k = hidden_size
8686
otype = torch.bfloat16
87-
wtype = torch.float8_e4m3fn
8887
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10
8988
w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous()
9089

benchmarks/bench_fused_add_rmsnorm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import argparse
2-
from typing import cast
32

43
import numpy as np
54
import torch
@@ -54,8 +53,8 @@ def fn() -> None:
5453
f"batch_size: {batch_size:3},",
5554
f"hidden_size: {hidden_size:5},",
5655
f"dtype: {dtype_str:8},",
57-
f"latency: {latency_ms*1e3:2.0f}us,",
58-
f"throughput: {throughput*1e-9:7.3f}GB/s",
56+
f"latency: {latency_ms * 1e3:2.0f}us,",
57+
f"throughput: {throughput * 1e-9:7.3f}GB/s",
5958
)
6059

6160
print("---")

benchmarks/bench_mixed_attention.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,7 @@ def run_bench(
213213
for idx, (p_q_lens, p_kv_lens, d_q_len, d_kv_len) in enumerate(
214214
zip(p_q_configs, p_kv_configs, d_q_len_configs, d_kv_len_configs)
215215
):
216-
217-
print(f"===== Benchmark {idx+1}: (kv_len, qo_len) set =====")
216+
print(f"===== Benchmark {idx + 1}: (kv_len, qo_len) set =====")
218217
run_bench(
219218
p_q_lens,
220219
p_kv_lens,

benchmarks/bench_renorm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def main():
4949
io = (probs.numel() * probs.element_size()) * 2
5050
bandwidth = io * 1e-6 / ms
5151
print(
52-
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
52+
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
5353
)
5454

5555
print("---")
@@ -75,7 +75,7 @@ def main():
7575
io = (probs.numel() * probs.element_size()) * 2
7676
bandwidth = io * 1e-6 / ms
7777
print(
78-
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
78+
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
7979
)
8080

8181
print("---")
@@ -100,7 +100,7 @@ def main():
100100
io = (logits.numel() * logits.element_size()) * 2
101101
bandwidth = io * 1e-6 / ms
102102
print(
103-
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms*1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
103+
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
104104
)
105105

106106

0 commit comments

Comments (0)