Skip to content

Commit 8587c21

Browse files
cyx-6, gemini-code-assist[bot], and yyihuang
authored
refactor: Improved metainfo for trtllm-gen fmha (#1292)
Refactor the metainfo for trtllm gen fmha. <!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Co-authored-by: Yaxing Cai <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Yingyi Huang <[email protected]>
1 parent fe29ed6 commit 8587c21

File tree

6 files changed

+168
-2284
lines changed

6 files changed

+168
-2284
lines changed

flashinfer/decode.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
get_batch_prefill_uri,
3434
get_single_decode_uri,
3535
setup_cubin_loader,
36+
setup_metainfo_loader,
3637
trtllm_fmha_gen_module,
3738
trtllm_mla_gen_module,
3839
)
@@ -302,6 +303,7 @@ def get_trtllm_fmha_gen_module():
302303
mod = trtllm_fmha_gen_module()
303304
op = mod.build_and_load()
304305
setup_cubin_loader(mod.get_library_path())
306+
setup_metainfo_loader(mod.get_library_path())
305307
return op
306308

307309

@@ -310,6 +312,7 @@ def get_trtllm_mla_gen_module():
310312
mod = trtllm_mla_gen_module()
311313
op = mod.build_and_load()
312314
setup_cubin_loader(mod.get_library_path())
315+
setup_metainfo_loader(mod.get_library_path())
313316
return op
314317

315318

flashinfer/jit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from .core import gen_jit_spec as gen_jit_spec
7070
from .core import sm90a_nvcc_flags as sm90a_nvcc_flags
7171
from .core import sm100a_nvcc_flags as sm100a_nvcc_flags
72-
from .cubin_loader import setup_cubin_loader
72+
from .cubin_loader import setup_cubin_loader, setup_metainfo_loader
7373

7474

7575
@functools.cache

flashinfer/jit/cubin_loader.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,32 @@ def get_cubin_callback(name, sha256):
188188
dll_cubin_handlers[dll_path] = cb
189189

190190
_LIB.FlashInferSetCubinCallback(cb)
191+
192+
193+
# Installed metainfo callbacks, keyed by shared-library path. Storing the ctypes
# callback objects here keeps a reference to them so they are not garbage
# collected after being handed to the native library.
dll_metainfo_handlers = {}


def setup_metainfo_loader(dll_path: str):
    """Install the metainfo callback into the shared library at ``dll_path``.

    The library is opened with ``ctypes.CDLL`` and given a callback through its
    ``FlashInferSetMetaInfoCallback`` entry point. When the library invokes the
    callback with ``(name, sha256, extension)`` byte strings, they are decoded
    as UTF-8 and passed to ``get_cubin``; the result is then pushed back to the
    library through ``FlashInferSetCurrentMetaInfo`` together with its length.

    Calling this again with a ``dll_path`` that is already registered is a
    no-op.
    """
    if dll_path in dll_metainfo_handlers:
        return

    lib = ctypes.CDLL(dll_path)

    # Signature: void callback(const char*, const char*, const char*)
    callback_type = ctypes.CFUNCTYPE(
        None, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p
    )

    def fetch_and_publish(name, sha256, extension):
        decoded_args = [raw.decode("utf-8") for raw in (name, sha256, extension)]
        payload = get_cubin(*decoded_args)
        lib.FlashInferSetCurrentMetaInfo(
            convert_to_ctypes_char_p(payload), ctypes.c_int(len(payload))
        )

    # Register the callback object before handing it to the library so a
    # reference to it is always retained.
    callback = callback_type(fetch_and_publish)
    dll_metainfo_handlers[dll_path] = callback

    lib.FlashInferSetMetaInfoCallback(callback)

include/flashinfer/cubin_loader.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,31 @@ std::string getCubin(const std::string& name, const std::string& sha256) {
5656
callbackGetCubin(name.c_str(), sha256.c_str());
5757
return current_cubin;
5858
}
59+
60+
// Hook used by `getMetaInfo()` to request a metainfo blob. It stays null until
// `FlashInferSetMetaInfoCallback()` installs a callback.
void (*callbackGetMetaInfo)(const char* path, const char* sha256, const char* extension) = nullptr;

// Set the python callback, called by the python code using ctypes.
extern "C" void FlashInferSetMetaInfoCallback(void (*callback)(const char* path, const char* sha256,
                                                               const char* extension)) {
  callbackGetMetaInfo = callback;
}

// Thread-local variable that stores the current metainfo.
// It is cleared at the start of every call to `getMetaInfo()`, so a callback
// that publishes nothing yields an empty string rather than the blob left over
// from an earlier request on the same thread.
thread_local std::string raw_metainfo;

// Called by the callback to set the current metainfo.
//
// `binary` points to `size` bytes, which are copied verbatim (embedded NUL
// bytes included). A null pointer or a non-positive size stores an empty
// string: a negative `int` would otherwise be converted to a huge unsigned
// count and make the string constructor throw out of this C-linkage function.
extern "C" void FlashInferSetCurrentMetaInfo(const char* binary, int size) {
  if (binary == nullptr || size <= 0) {
    raw_metainfo.clear();
    return;
  }
  raw_metainfo.assign(binary, static_cast<std::string::size_type>(size));
}

// Get the metainfo from the python callback.
// This is the API for the native library to use.
//
// Invokes the installed callback with (`name`, `sha256`, `extension`) and
// returns a copy of whatever it published through
// `FlashInferSetCurrentMetaInfo()`; the result is empty if it published
// nothing. Throws `std::runtime_error` when no callback has been installed.
std::string getMetaInfo(const std::string& name, const std::string& sha256,
                        const std::string& extension) {
  if (!callbackGetMetaInfo) {
    throw std::runtime_error("FlashInferSetMetaInfoCallback not set");
  }
  // Drop the previous result first so a callback that fails to publish cannot
  // cause a stale blob to be returned.
  raw_metainfo.clear();
  callbackGetMetaInfo(name.c_str(), sha256.c_str(), extension.c_str());
  return raw_metainfo;
}

0 commit comments

Comments
 (0)