Skip to content

Commit 5f78377

Browse files
authored
bugfix: deep_gemm artifact load path (#1838)
1 parent 398dc26 commit 5f78377

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

flashinfer/deep_gemm.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -934,8 +934,9 @@ def load_all():
 934  934    if cubin_name in RUNTIME_CACHE:
 935  935    continue
 936  936    symbol, sha256 = KERNEL_MAP[cubin_name]
 937       - get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
 938       - path = FLASHINFER_CUBIN_DIR / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
      937  + cubin_name = cubin_name + ".cubin"
      938  + get_cubin(ArtifactPath.DEEPGEMM + "/" + cubin_name, sha256)
      939  + path = FLASHINFER_CUBIN_DIR / ArtifactPath.DEEPGEMM / cubin_name
 939  940    assert path.exists()
 940  941    RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)
 941  942

@@ -948,8 +949,9 @@ def load(name: str, code: str) -> SM100FP8GemmRuntime:
 948  949    if cubin_name in RUNTIME_CACHE:
 949  950    return RUNTIME_CACHE[cubin_name]
 950  951    symbol, sha256 = KERNEL_MAP[cubin_name]
 951       - get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
 952       - path = FLASHINFER_CUBIN_DIR / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
      952  + cubin_name = cubin_name + ".cubin"
      953  + get_cubin(ArtifactPath.DEEPGEMM + "/" + cubin_name, sha256)
      954  + path = FLASHINFER_CUBIN_DIR / ArtifactPath.DEEPGEMM / cubin_name
 953  955    assert path.exists()
 954  956    RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)
 955  957    return RUNTIME_CACHE[cubin_name]
@@ -1490,11 +1492,11 @@ def __init__(self, sha256: str):
 1490  1492    self.indice = None
 1491  1493
 1492  1494    def init_indices(self):
 1493        - indice_path = ArtifactPath.DEEPGEMM + "kernel_map.json"
       1495  + indice_path = ArtifactPath.DEEPGEMM + "/" + "kernel_map.json"
 1494  1496    assert get_cubin(indice_path, self.sha256), (
 1495  1497    "cubin kernel map file not found, nor downloaded with matched sha256"
 1496  1498    )
 1497        - path = FLASHINFER_CUBIN_DIR / f"{indice_path}.json"
       1499  + path = FLASHINFER_CUBIN_DIR / indice_path
 1498  1500    assert path.exists()
 1499  1501    with open(path, "r") as f:
 1500  1502    self.indice = json.load(f)

0 commit comments

Comments
 (0)