|
1 | 1 | import dataclasses
|
2 | 2 | import logging
|
3 | 3 | import os
|
4 |
| -import tvm_ffi |
5 | 4 | from contextlib import nullcontext
|
| 5 | +from datetime import datetime |
6 | 6 | from pathlib import Path
|
7 | 7 | from typing import Dict, List, Optional, Sequence, Union
|
8 |
| -from datetime import datetime |
9 | 8 |
|
| 9 | +import tvm_ffi |
10 | 10 | from filelock import FileLock
|
11 | 11 |
|
| 12 | +from ..compilation_context import CompilationContext |
12 | 13 | from . import env as jit_env
|
13 | 14 | from .cpp_ext import generate_ninja_build_for_op, run_ninja
|
14 | 15 | from .utils import write_if_different
|
15 |
| -from ..compilation_context import CompilationContext |
16 | 16 |
|
17 | 17 | os.makedirs(jit_env.FLASHINFER_WORKSPACE_DIR, exist_ok=True)
|
18 | 18 | os.makedirs(jit_env.FLASHINFER_CSRC_DIR, exist_ok=True)
|
@@ -194,6 +194,16 @@ def get_library_path(self) -> Path:
|
194 | 194 | return self.aot_path
|
195 | 195 | return self.jit_library_path
|
196 | 196 |
|
| 197 | + def get_object_paths(self) -> List[Path]: |
| 198 | + object_paths = [] |
| 199 | + jit_dir = self.jit_library_path.parent |
| 200 | + for source in self.sources: |
| 201 | + is_cuda = source.suffix == ".cu" |
| 202 | + object_suffix = ".cuda.o" if is_cuda else ".o" |
| 203 | + obj_name = source.with_suffix(object_suffix).name |
| 204 | + object_paths.append(jit_dir / obj_name) |
| 205 | + return object_paths |
| 206 | + |
197 | 207 | @property
|
198 | 208 | def aot_path(self) -> Path:
|
199 | 209 | return jit_env.FLASHINFER_AOT_DIR / self.name / f"{self.name}.so"
|
|
0 commit comments