Skip to content

Commit b68772e

Browse files
committed
fix types
1 parent 71c84c8 commit b68772e

File tree

1 file changed

+34
-11
lines changed

1 file changed

+34
-11
lines changed

flashinfer/artifacts.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from typing import Generator
2323
import requests # type: ignore[import-untyped]
2424
import shutil
25-
import hashlib
2625

2726
from .jit.core import logger
2827
from .jit.cubin_loader import (
@@ -85,6 +84,7 @@ class ArtifactPath:
8584
CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
8685
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
8786

87+
8888
# TODO: Should be deprecated
8989
@dataclass(frozen=True)
9090
class MetaInfoHash:
@@ -102,12 +102,18 @@ class MetaInfoHash:
102102

103103
# @dataclass(frozen=True)
104104
class CheckSumHash:
105-
TRTLLM_GEN_FMHA: str = "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
106-
TRTLLM_GEN_BMM: str = "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
107-
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
108-
TRTLLM_GEN_GEMM: str = "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
105+
TRTLLM_GEN_FMHA: str = (
106+
"b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
107+
)
108+
TRTLLM_GEN_BMM: str = (
109+
"efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
110+
)
111+
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
112+
TRTLLM_GEN_GEMM: str = (
113+
"e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
114+
)
109115

110-
map_checksums: [dict[str, str]] = {
116+
map_checksums: dict[str, str] = {
111117
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA,
112118
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM,
113119
safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM,
@@ -118,7 +124,9 @@ class CheckSumHash:
118124
def get_checksums(subdirs):
119125
checksums = {}
120126
for subdir in subdirs:
121-
uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, safe_urljoin(subdir, "checksums.txt"))
127+
uri = safe_urljoin(
128+
FLASHINFER_CUBINS_REPOSITORY, safe_urljoin(subdir, "checksums.txt")
129+
)
122130
checksum_path = FLASHINFER_CUBIN_DIR / safe_urljoin(subdir, "checksums.txt")
123131
download_file(uri, checksum_path)
124132
with open(checksum_path, "r") as f:
@@ -132,7 +140,7 @@ def get_checksums(subdirs):
132140
return checksums
133141

134142

135-
def get_subdir_file_list():
143+
def get_subdir_file_list() -> Generator[tuple[str, str], None, None]:
136144
base = FLASHINFER_CUBINS_REPOSITORY
137145

138146
cubin_dirs = [
@@ -146,9 +154,24 @@ def get_subdir_file_list():
146154
checksums = get_checksums(cubin_dirs)
147155

148156
# The meta info header files first.
149-
yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h")])
150-
yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h")])
151-
yield (safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"), checksums[safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h")])
157+
yield (
158+
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h"),
159+
checksums[
160+
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "include/flashInferMetaInfo.h")
161+
],
162+
)
163+
yield (
164+
safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h"),
165+
checksums[
166+
safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "include/flashinferMetaInfo.h")
167+
],
168+
)
169+
yield (
170+
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h"),
171+
checksums[
172+
safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "include/flashinferMetaInfo.h")
173+
],
174+
)
152175

153176
# All the actual kernel cubin's.
154177
for cubin_dir in cubin_dirs:

0 commit comments

Comments
 (0)