22
22
from typing import Generator
23
23
import requests # type: ignore[import-untyped]
24
24
import shutil
25
- import hashlib
26
25
27
26
from .jit .core import logger
28
27
from .jit .cubin_loader import (
@@ -85,6 +84,7 @@ class ArtifactPath:
85
84
CUDNN_SDPA : str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
86
85
DEEPGEMM : str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
87
86
87
+
88
88
# TODO: Should be deprecated
89
89
@dataclass (frozen = True )
90
90
class MetaInfoHash :
@@ -102,12 +102,18 @@ class MetaInfoHash:
102
102
103
103
# @dataclass(frozen=True)
104
104
class CheckSumHash :
105
- TRTLLM_GEN_FMHA : str = "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
106
- TRTLLM_GEN_BMM : str = "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
107
- DEEPGEMM : str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
108
- TRTLLM_GEN_GEMM : str = "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
105
+ TRTLLM_GEN_FMHA : str = (
106
+ "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
107
+ )
108
+ TRTLLM_GEN_BMM : str = (
109
+ "efb9379c924193f6d3cb792bafb12b0811cab8eaa12bf324c7c410636c7769cd"
110
+ )
111
+ DEEPGEMM : str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
112
+ TRTLLM_GEN_GEMM : str = (
113
+ "e475e37989eed16418e0e858e2868ff07cb4b650cc48759cc23012f1afea310a"
114
+ )
109
115
110
- map_checksums : [ dict [str , str ] ] = {
116
+ map_checksums : dict [str , str ] = {
111
117
safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "checksums.txt" ): TRTLLM_GEN_FMHA ,
112
118
safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "checksums.txt" ): TRTLLM_GEN_BMM ,
113
119
safe_urljoin (ArtifactPath .DEEPGEMM , "checksums.txt" ): DEEPGEMM ,
@@ -118,7 +124,9 @@ class CheckSumHash:
118
124
def get_checksums (subdirs ):
119
125
checksums = {}
120
126
for subdir in subdirs :
121
- uri = safe_urljoin (FLASHINFER_CUBINS_REPOSITORY , safe_urljoin (subdir , "checksums.txt" ))
127
+ uri = safe_urljoin (
128
+ FLASHINFER_CUBINS_REPOSITORY , safe_urljoin (subdir , "checksums.txt" )
129
+ )
122
130
checksum_path = FLASHINFER_CUBIN_DIR / safe_urljoin (subdir , "checksums.txt" )
123
131
download_file (uri , checksum_path )
124
132
with open (checksum_path , "r" ) as f :
@@ -132,7 +140,7 @@ def get_checksums(subdirs):
132
140
return checksums
133
141
134
142
135
- def get_subdir_file_list ():
143
+ def get_subdir_file_list () -> Generator [ tuple [ str , str ], None , None ] :
136
144
base = FLASHINFER_CUBINS_REPOSITORY
137
145
138
146
cubin_dirs = [
@@ -146,9 +154,24 @@ def get_subdir_file_list():
146
154
checksums = get_checksums (cubin_dirs )
147
155
148
156
# The meta info header files first.
149
- yield (safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "include/flashInferMetaInfo.h" ), checksums [safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "include/flashInferMetaInfo.h" )])
150
- yield (safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "include/flashinferMetaInfo.h" ), checksums [safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "include/flashinferMetaInfo.h" )])
151
- yield (safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "include/flashinferMetaInfo.h" ), checksums [safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "include/flashinferMetaInfo.h" )])
157
+ yield (
158
+ safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "include/flashInferMetaInfo.h" ),
159
+ checksums [
160
+ safe_urljoin (ArtifactPath .TRTLLM_GEN_FMHA , "include/flashInferMetaInfo.h" )
161
+ ],
162
+ )
163
+ yield (
164
+ safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "include/flashinferMetaInfo.h" ),
165
+ checksums [
166
+ safe_urljoin (ArtifactPath .TRTLLM_GEN_GEMM , "include/flashinferMetaInfo.h" )
167
+ ],
168
+ )
169
+ yield (
170
+ safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "include/flashinferMetaInfo.h" ),
171
+ checksums [
172
+ safe_urljoin (ArtifactPath .TRTLLM_GEN_BMM , "include/flashinferMetaInfo.h" )
173
+ ],
174
+ )
152
175
153
176
# All the actual kernel cubin's.
154
177
for cubin_dir in cubin_dirs :
0 commit comments