Skip to content

Commit 85445d5

Browse files
authored
jit: add get_object_paths to JitSpec (#1836)
## 📌 Description This commit reverts changes in #1802, and adds a new method `get_object_paths` to JITSpec, which returns the paths of all compiled object files. For applications that want to get the object files rather than the loaded shared library (such as MLC model compilation), this method could be leveraged. ## 🔍 Related Issues N/A ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes
1 parent 35e099e commit 85445d5

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

csrc/batch_decode.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
#include <flashinfer/attention/decode.cuh>
1716
#include <flashinfer/attention/scheduler.cuh>
1817
#include <flashinfer/pos_enc.cuh>
1918
#include <flashinfer/utils.cuh>

flashinfer/jit/core.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
import dataclasses
22
import logging
33
import os
4-
import tvm_ffi
54
from contextlib import nullcontext
5+
from datetime import datetime
66
from pathlib import Path
77
from typing import Dict, List, Optional, Sequence, Union
8-
from datetime import datetime
98

9+
import tvm_ffi
1010
from filelock import FileLock
1111

12+
from ..compilation_context import CompilationContext
1213
from . import env as jit_env
1314
from .cpp_ext import generate_ninja_build_for_op, run_ninja
1415
from .utils import write_if_different
15-
from ..compilation_context import CompilationContext
1616

1717
os.makedirs(jit_env.FLASHINFER_WORKSPACE_DIR, exist_ok=True)
1818
os.makedirs(jit_env.FLASHINFER_CSRC_DIR, exist_ok=True)
@@ -194,6 +194,16 @@ def get_library_path(self) -> Path:
194194
return self.aot_path
195195
return self.jit_library_path
196196

197+
def get_object_paths(self) -> List[Path]:
198+
object_paths = []
199+
jit_dir = self.jit_library_path.parent
200+
for source in self.sources:
201+
is_cuda = source.suffix == ".cu"
202+
object_suffix = ".cuda.o" if is_cuda else ".o"
203+
obj_name = source.with_suffix(object_suffix).name
204+
object_paths.append(jit_dir / obj_name)
205+
return object_paths
206+
197207
@property
198208
def aot_path(self) -> Path:
199209
return jit_env.FLASHINFER_AOT_DIR / self.name / f"{self.name}.so"

0 commit comments

Comments
 (0)