Skip to content

Commit c0b83c7

Browse files
committed
fix types
1 parent 6d0f0ca commit c0b83c7

File tree

1 file changed

+34
-11
lines changed

1 file changed

+34
-11
lines changed

flashinfer/artifacts.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from typing import Generator
2424
import requests # type: ignore[import-untyped]
2525
import shutil
26-
import hashlib
2726

2827
# Create logger for artifacts module to avoid circular import with jit.core
2928
logger = logging.getLogger("flashinfer.artifacts")
@@ -91,6 +90,7 @@ class ArtifactPath:
9190
CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
9291
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
9392

93+
9494
# TODO: Should be deprecated
9595
@dataclass(frozen=True)
9696
class MetaInfoHash:
@@ -108,12 +108,18 @@ class MetaInfoHash:
108108

109109
# @dataclass(frozen=True)
110110
class CheckSumHash:
111-
TRTLLM_GEN_FMHA: str = "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
112-
TRTLLM_GEN_BMM: str = "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
113-
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
114-
TRTLLM_GEN_GEMM: str = "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
111+
TRTLLM_GEN_FMHA: str = (
112+
"b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
113+
)
114+
TRTLLM_GEN_BMM: str = (
115+
"efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
116+
)
117+
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
118+
TRTLLM_GEN_GEMM: str = (
119+
"e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
120+
)
115121

116-
map_checksums: [dict[str, str]] = {
122+
map_checksums: dict[str, str] = {
117123
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA,
118124
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM,
119125
safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM,
@@ -124,7 +130,9 @@ class CheckSumHash:
124130
def get_checksums(subdirs):
125131
checksums = {}
126132
for subdir in subdirs:
127-
uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, safe_urljoin(subdir, "checksums.txt"))
133+
uri = safe_urljoin(
134+
FLASHINFER_CUBINS_REPOSITORY, safe_urljoin(subdir, "checksums.txt")
135+
)
128136
checksum_path = FLASHINFER_CUBIN_DIR / safe_urljoin(subdir, "checksums.txt")
129137
download_file(uri, checksum_path)
130138
with open(checksum_path, "r") as f:
@@ -138,7 +146,7 @@ def get_checksums(subdirs):
138146
return checksums
139147

140148

141-
def get_subdir_file_list():
149+
def get_subdir_file_list() -> Generator[tuple[str, str], None, None]:
142150
base = FLASHINFER_CUBINS_REPOSITORY
143151

144152
cubin_dirs = [
@@ -152,9 +160,24 @@ def get_subdir_file_list():
152160
checksums = get_checksums(cubin_dirs)
153161

154162
# The meta info header files first.
155-
yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h")])
156-
yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h")])
157-
yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h")])
163+
yield (
164+
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"),
165+
checksums[
166+
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h")
167+
],
168+
)
169+
yield (
170+
safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"),
171+
checksums[
172+
safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h")
173+
],
174+
)
175+
yield (
176+
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"),
177+
checksums[
178+
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h")
179+
],
180+
)
158181

159182
# All the actual kernel cubin's.
160183
for cubin_dir in cubin_dirs:

0 commit comments

Comments
 (0)