
Commit 85d75ca

Add mypy to pre-commit (#1179)
1 parent 73702b7 commit 85d75ca

34 files changed: +180 −196 lines

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
@@ -54,3 +54,8 @@ repos:
       types_or: [c++, c, cuda]
       exclude: |
         (?x)^(3rdparty/.* flashinfer/jit/aot_config.py)$
+
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: ''  # Use the sha / tag you want to point at
+    hooks:
+      - id: mypy
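
The new hook runs mypy on the repository's Python files at each commit. Note that rev: '' is the placeholder from the mirrors-mypy README and is normally pinned to a concrete tag (for example via pre-commit autoupdate) before the hook can run; once configured, the whole tree can be checked with pre-commit run mypy --all-files. As a minimal sketch of the class of bug this catches statically (a hypothetical snippet, not code from this commit; the flashinfer/cudnn/prefill.py hunk below fixes exactly this pattern):

    from typing import Tuple

    def head_geometry(shape: Tuple[int, int, int]) -> int:
        # mypy: error: Too many values to unpack (2 expected, 3 provided)
        heads, dim = shape
        return heads * dim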

benchmarks/bench_mixed_attention.py

Lines changed: 2 additions & 2 deletions
@@ -185,8 +185,8 @@ def run_bench(
     full_kv_len = np.random.randint(2000, 16000, size=bsz)
     p_q_lens = []
     p_kv_lens = []
-    d_q_len = []
-    d_kv_len = []
+    d_q_lens = []
+    d_kv_lens = []
 
     for i in range(bsz):
         if i % stride == 0:

ci/scripts/jenkins/git_utils.py

Lines changed: 9 additions & 12 deletions
@@ -30,13 +30,13 @@
 
 def compress_query(query: str) -> str:
     query = query.replace("\n", "")
-    query = re.sub("\s+", " ", query)
+    query = re.sub(r"\s+", " ", query)
     return query
 
 
 def post(url: str, body: Optional[Any] = None, auth: Optional[Tuple[str, str]] = None):
     logging.info(f"Requesting POST to {url} with {body}")
-    headers = {}
+    headers: Dict[Any, Any] = {}
     req = request.Request(url, headers=headers, method="POST")
     if auth is not None:
         auth_str = base64.b64encode(f"{auth[0]}:{auth[1]}".encode())
@@ -46,9 +46,8 @@ def post(url: str, body: Optional[Any] = None, auth: Optional[Tuple[str, str]] =
         body = ""
 
     req.add_header("Content-Type", "application/json; charset=utf-8")
-    data = json.dumps(body)
-    data = data.encode("utf-8")
-    req.add_header("Content-Length", len(data))
+    data = json.dumps(body).encode("utf-8")
+    req.add_header("Content-Length", str(len(data)))
 
     with request.urlopen(req, data) as response:
         return response.read()
@@ -119,9 +118,8 @@ def _request(
         logging.info(f"Requesting {method} to {full_url} with {body}")
         req = request.Request(full_url, headers=self.headers(), method=method.upper())
         req.add_header("Content-Type", "application/json; charset=utf-8")
-        data = json.dumps(body)
-        data = data.encode("utf-8")
-        req.add_header("Content-Length", len(data))
+        data = json.dumps(body).encode("utf-8")
+        req.add_header("Content-Length", str(len(data)))
 
         try:
             with request.urlopen(req, data) as response:
@@ -206,12 +204,11 @@ def find_ccs(body: str) -> List[str]:
     matches = re.findall(r"(cc( @[-A-Za-z0-9]+)+)", body, flags=re.MULTILINE)
     matches = [full for full, last in matches]
 
-    reviewers = []
+    reviewers = set()
     for match in matches:
         if match.startswith("cc "):
             match = match.replace("cc ", "")
         users = [x.strip() for x in match.split("@")]
-        reviewers += users
+        reviewers.update(users)
 
-    reviewers = set(x for x in reviewers if x != "")
-    return list(reviewers)
+    return [x for x in reviewers if x != ""]
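
Two of these fixes are worth a closer look. In a non-raw string literal "\s" is an invalid escape sequence that newer Pythons warn about, so the pattern becomes a raw string; a quick sketch:

    import re

    # Non-raw "\s+" draws "SyntaxWarning: invalid escape sequence '\s'" on
    # recent Python versions; the raw string states exactly what re sees.
    print(re.sub(r"\s+", " ", "a \t b\n  c"))  # -> "a b c"

And typeshed declares Request.add_header(key: str, val: str), so the Content-Length value is stringified explicitly; http.client converts int header values at send time anyway, which is why the old code happened to work at runtime. The find_ccs rewrite deduplicates reviewers in a set while collecting rather than at the end, with unchanged results.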

docs/conf.py

Lines changed: 3 additions & 6 deletions
@@ -2,6 +2,7 @@
 import sys
 import warnings
 from pathlib import Path
+from typing import Any, List
 
 # import tlcpack_sphinx_addon
 # Configuration file for the Sphinx documentation builder.
@@ -64,16 +65,12 @@
 
 html_theme = "furo"  # "sphinx_rtd_theme"
 
-templates_path = []
+templates_path: List[Any] = []
 
-html_static_path = []
+html_static_path = ["_static"]
 
 html_theme_options = {
     "logo_only": True,
-}
-
-html_static_path = ["_static"]
-html_theme_options = {
     "light_logo": "FlashInfer-white-background.png",
     "dark_logo": "FlashInfer-black-background.png",
 }
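
Two different mypy complaints drive this hunk. The empty templates_path literal needs an annotation because mypy cannot infer an element type for a bare [], and the duplicated html_static_path / html_theme_options definitions are merged so each name is assigned once; re-assigning html_theme_options with a dict of a different value type would otherwise be rejected as an incompatible assignment. A minimal sketch of the first error:

    from typing import Any, List

    # Unannotated:   paths = []
    # mypy reports:  error: Need type annotation for "paths"  [var-annotated]
    paths: List[Any] = []

Merging the dicts also fixes a real documentation bug: the second html_theme_options assignment used to silently discard the logo_only setting.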

flashinfer/artifacts.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-import requests
+import requests  # type: ignore[import-untyped]
 
 from .jit.core import logger
 from .jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY, get_cubin
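
requests ships without bundled type information, so under the new hook a bare import fails with "Library stubs not installed for 'requests'" (error code import-untyped). The inline ignore is the lightweight fix; installing the types-requests stubs package would be the alternative if checking of the requests API itself were wanted.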

flashinfer/autotuner.py

Lines changed: 14 additions & 10 deletions
@@ -583,20 +583,22 @@ def _optimization_profiles(
 
         generated_profiles: List[OptimizationProfile] = []
 
-        dynamic_dims = []
+        dynamic_dims: List[Tuple[Any, ...]] = []
 
         for spec in tuning_config.dynamic_tensor_specs:
            assert inspect.isfunction(spec.gen_tuning_buckets) or isinstance(
                spec.gen_tuning_buckets, (list, tuple)
            ), "The given dynamic dimension must provide a opt value generation function or a list of opt values"
            if inspect.isfunction(spec.gen_tuning_buckets):
                opt_shapes = spec.gen_tuning_buckets(
-                   base_profile.shapes[spec.input_idx][spec.dim_idx].val
+                   base_profile.shapes[spec.input_idx][spec.dim_idx]._opt()
                )
            else:
                opt_shapes = spec.gen_tuning_buckets
-           opt_shapes_max = tuple(opt_shapes[1:]) + (float("inf"),)
-           opt_shapes_max = {v1: v2 for v1, v2 in zip(opt_shapes, opt_shapes_max)}
+           opt_shapes_max = {
+               v1: v2
+               for v1, v2 in zip(opt_shapes, tuple(opt_shapes[1:]) + (float("inf"),))
+           }
            dynamic_dims.append(
                (spec.input_idx, spec.dim_idx, opt_shapes_max, opt_shapes)
            )
@@ -617,10 +619,12 @@ def _optimization_profiles(
            )
 
            # Adjust the profile to satisfy the constraints
-           for spec in tuning_config.constraint_specs:
-               min_value = opt_value = max_value = spec.infer_shape(p.get_opt_shapes())
-               p.shapes[spec.input_idx][spec.dim_idx] = DynamicDim(
-                   min_value, opt_value, max_value
+           for constraint_spec in tuning_config.constraint_specs:
+               min_value = opt_value = max_value = constraint_spec.infer_shape(
+                   p.get_opt_shapes()
+               )
+               p.shapes[constraint_spec.input_idx][constraint_spec.dim_idx] = (
+                   DynamicDim(min_value, opt_value, max_value)
                )
            generated_profiles.append(p)
            logger.debug(f"[Autotuner]: generated profile: {p}")
@@ -651,8 +655,8 @@ def _find_nearest_profile(
        )
 
        # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile
-       for spec in tuning_config.constraint_specs:
-           base_profile[spec.input_idx][spec.dim_idx] = -1
+       for constraint_spec in tuning_config.constraint_specs:
+           base_profile[constraint_spec.input_idx][constraint_spec.dim_idx] = -1
 
        return tuple(tuple(shape) for shape in base_profile)

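The spec to constraint_spec rename is a typing fix rather than a cosmetic one: dynamic_tensor_specs and constraint_specs yield different types, and reusing one loop variable for both makes mypy bind the name to the first element type and reject the second loop. The same reasoning applies to opt_shapes_max, which was first a tuple and then rebound to a dict; the rewrite builds the dict in a single expression. A minimal sketch of the loop-variable error (hypothetical names):

    from typing import List

    def total(ints: List[int], names: List[str]) -> int:
        acc = 0
        for item in ints:
            acc += item
        # Reusing "item" below would yield:
        #   error: Incompatible types in assignment (expression has type
        #   "str", variable has type "int")
        for name in names:
            acc += len(name)
        return acc
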
flashinfer/comm/mnnvl.py

Lines changed: 10 additions & 9 deletions
@@ -19,7 +19,7 @@
 import platform
 import sys
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 from cuda import cuda
@@ -132,16 +132,17 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
 
 
 if IS_BUILDING_DOCS:
+    # Mock classes for building docs
 
-    class MpiComm:
+    class MpiComm:  # type: ignore[no-redef]
        @classmethod
        def set_mpi_comm(cls, new_comm):
            pass
 
        def __getattr__(self, name):
            return None
 
-    class MnnvlMemory:
+    class MnnvlMemory:  # type: ignore[no-redef]
        initialized: bool = False
 
        current_mem_offset: int = 0
@@ -159,8 +160,8 @@ class MnnvlMemory:
 
        dev_id: int = None
 
-       allocated_map = {}
-       address_refcnt = {}
+       allocated_map: Dict[int, Any] = {}
+       address_refcnt: Dict[int, Any] = {}
 
        def __init__(self, mapping: Mapping, size: int):
            pass
@@ -211,7 +212,7 @@ def supports_mnnvl() -> bool:
    import pynvml
    from mpi4py import MPI
 
-   class MpiComm:
+   class MpiComm:  # type: ignore[no-redef]
        _comm: MPI.Intracomm = MPI.COMM_WORLD
 
        @classmethod
@@ -221,7 +222,7 @@ def set_mpi_comm(cls, new_comm: MPI.Intracomm):
        def __getattr__(self, name):
            return getattr(self._comm, name)
 
-   class MnnvlMemory:
+   class MnnvlMemory:  # type: ignore[no-redef]
        initialized: bool = False
 
        current_mem_offset: int = 0
@@ -239,8 +240,8 @@ class MnnvlMemory:
 
        dev_id: int = None
 
-       allocated_map = {}
-       address_refcnt = {}
+       allocated_map: Dict[int, Any] = {}
+       address_refcnt: Dict[int, Any] = {}
 
        def __init__(self, mapping: Mapping, size: int):
            self.mapping = mapping
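
MpiComm and MnnvlMemory are intentionally defined twice in this module: once as lightweight mocks when IS_BUILDING_DOCS is set, and once with the real MPI-backed implementations. mypy reports the second definition of each name as an error, so the redefinitions are acknowledged explicitly with type: ignore[no-redef]. A minimal sketch of the pattern (hypothetical names):

    BUILDING_DOCS = False

    if BUILDING_DOCS:

        class Memory:
            def size(self) -> int:
                return 0

    else:

        class Memory:  # type: ignore[no-redef]
            def size(self) -> int:
                return 4096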

flashinfer/comm/trtllm_ar.py

Lines changed: 9 additions & 6 deletions
@@ -16,6 +16,7 @@
 
 import functools
 import logging
+from ctypes import c_void_p
 from dataclasses import dataclass
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -417,7 +418,7 @@ def trtllm_create_ipc_workspace_for_all_reduce(
     max_token_num: int,
     hidden_dim,
     group: Optional[ProcessGroup] = None,
-) -> List[int]:
+) -> List[List[int]]:
     """
     Parameters:
     - rank: the rank of the current process.
@@ -492,7 +493,7 @@ def trtllm_create_ipc_workspace_for_all_reduce(
 
 
 def trtllm_destroy_ipc_workspace_for_all_reduce(
-    workspace: List[int], group: Optional[ProcessGroup] = None
+    workspace: List[List[int]], group: Optional[ProcessGroup] = None
 ) -> None:
     """
     Note:
@@ -518,7 +519,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
     hidden_dim,
     use_fp32_lamport: bool = False,
     group: Optional[ProcessGroup] = None,
-) -> List[int]:
+) -> Tuple[List[List[int]], torch.Tensor]:
     """
     Parameters:
     - tp_rank: the rank of the current process.
@@ -564,7 +565,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
     # we should init 3 buffers for all reduce fusion:
     # [buffer_size, flag_size, lamport_buffer_size]
 
-    ipc_handles = list()
+    ipc_handles: List[List[int]] = list()
     for size in [buffer_size, flag_size, lamport_buffer_size]:
         # todo(review): confirm we need this alignment
         # all sizes should be aligned to 1LU << 21 bytes (2MB)
@@ -609,7 +610,9 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
     cudart.cudaMemset(flag_ptr, 0, 5 * 4)
     # Set flag_ptr[3] = lamport_comm_size
     lamport_comm_size_bytes = lamport_comm_size.to_bytes(4, byteorder="little")
-    cudart.cudaMemcpy(flag_ptr.value + 3 * 4, lamport_comm_size_bytes, 4)
+    cudart.cudaMemcpy(
+        c_void_p(flag_ptr.value + 3 * 4), c_void_p(lamport_comm_size_bytes), 4
+    )
     print("set flag_ptr[3] = lamport_comm_size: ", lamport_comm_size)
     # add flag_ptr to workspace
     workspace.append(flag_ptr.value)
@@ -628,7 +631,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
 
 
 def trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
-    workspace: List[int], group: Optional[ProcessGroup] = None
+    workspace: List[List[int]], group: Optional[ProcessGroup] = None
 ) -> None:
     """
     Parameters:
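
The signature changes record what the code already returned: each element appended to the workspace is itself a list of per-rank IPC pointers, so the element type is List[int] rather than int, and the fusion variant returns the metadata tensor alongside the handles. The cudaMemcpy change wraps the raw integer address in ctypes.c_void_p, presumably to match the pointer-like parameter types the ctypes-level CUDA binding declares. A small sketch of the wrapping (the address is illustrative):

    from ctypes import c_void_p

    flag_addr = 0x7F00DEAD0000         # device address held as a plain int
    ptr = c_void_p(flag_addr + 3 * 4)  # pointer-sized wrapper, offset applied
    print(hex(ptr.value))              # -> 0x7f00dead000c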

flashinfer/cudnn/prefill.py

Lines changed: 1 addition & 5 deletions
@@ -143,11 +143,7 @@ def _build_prefill_graph(
     if q.dim() == 3:
         h_qo, d_qk = q.shape[1], q.shape[2]
     elif q.dim() == 4:
-        h_qo, d_qk = (
-            q.shape[1],
-            q.shape[2],
-            q.shape[3],
-        )
+        h_qo, d_qk = q.shape[2], q.shape[3]
     else:
         raise ValueError(f"Invalid query tensor shape: {q.shape}")

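This hunk is the clearest payoff of the new hook: the 4-D branch unpacked a three-element tuple into two names, which mypy flags statically and which would have raised "ValueError: too many values to unpack" at runtime the first time a 4-D query reached it. Assuming the [batch, seq, heads, head_dim] layout the indices suggest, the corrected branch behaves like this sketch:

    import torch

    q = torch.zeros(2, 128, 8, 64)   # hypothetical [batch, seq, heads, head_dim]
    h_qo, d_qk = q.shape[2], q.shape[3]
    print(h_qo, d_qk)                # -> 8 64
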
flashinfer/decode.py

Lines changed: 12 additions & 12 deletions
@@ -1791,6 +1791,18 @@ def run(
 
 
 class TrtllmGenDecodeModule:
+    def __init__(self) -> None:
+        self._sm_count: Optional[int] = None
+        self._mod = trtllm_gen_fmha_module()
+        self._op = self._mod.build_and_load()
+        from flashinfer.jit.cubin_loader import (
+            setup_cubin_loader,
+            setup_metainfo_loader,
+        )
+
+        setup_cubin_loader(self._mod.get_library_path())
+        setup_metainfo_loader(self._mod.get_library_path())
+
     def _paged_run(
         self,
         query: torch.Tensor,
@@ -1836,18 +1848,6 @@ def _paged_run(
     def _plan(self, *args, **kwargs):
         pass
 
-    def __init__(self):
-        self._sm_count: Optional[int] = None
-        self._mod = trtllm_gen_fmha_module()
-        self._op = self._mod.build_and_load()
-        from flashinfer.jit.cubin_loader import (
-            setup_cubin_loader,
-            setup_metainfo_loader,
-        )
-
-        setup_cubin_loader(self._mod.get_library_path())
-        setup_metainfo_loader(self._mod.get_library_path())
-
 
 @functools.cache
 def get_trtllm_gen_decode_module(*args):
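
Moving __init__ to the head of TrtllmGenDecodeModule is conventional ordering, but the added -> None is what matters to the checker: by default mypy skips the body of a function that carries no annotations at all, so the previously unannotated __init__ went unchecked. A minimal sketch:

    class Example:
        # With no annotations mypy would skip this body entirely;
        # "-> None" opts it in to checking.
        def __init__(self) -> None:
            self._sm_count: int = 0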
