20 changes: 20 additions & 0 deletions rtp_llm/models_py/modules/hybrid/test/BUILD
@@ -3,6 +3,11 @@ test_envs = {
"DEVICE_RESERVE_MEMORY_BYTES": "512000000", # 512MB
}

flashinfer_test_envs = {
"DEVICE_RESERVE_MEMORY_BYTES": "512000000", # 512MB
"FLASHINFER_JIT_VERBOSE": "1", # Enable verbose JIT compilation logs
}

py_test_deps = [
"//rtp_llm/models_py/standalone:py_standalone_testlib",
]
@@ -54,3 +59,18 @@ py_test (
tags = ["open_skip", "H20"],
exec_properties = {'gpu':'H20'},
)

py_test (
name = "flashinfer_jit_test",
srcs = ["flashinfer_jit_test.py"],
deps = [
"//rtp_llm/models_py:models",
"//rtp_llm:config",
"//rtp_llm:utils",
"//rtp_llm:testlib",
"//rtp_llm/test/model_test/test_util:test_util"
],
env = flashinfer_test_envs,
tags = ["open_skip", "H20"],
exec_properties = {'gpu':'H20'},
)
101 changes: 101 additions & 0 deletions rtp_llm/models_py/modules/hybrid/test/flashinfer_jit_test.py
@@ -0,0 +1,101 @@
import logging
import os
import sys
import unittest

import torch

from rtp_llm.models_py.modules.factory.attention.cuda_mla_impl.flashinfer_mla import (
warmup_flashinfer_python,
)

# Apply torch_patch to fix the flashinfer JIT compilation path issue
from rtp_llm.utils import torch_patch # noqa: F401


class FlashInferJitTest(unittest.TestCase):
"""测试flashinfer JIT功能是否在Bazel环境中正常工作"""

def setUp(self):
if not torch.cuda.is_available():
self.skipTest("CUDA is not available")

# Configure the logging level
logging.basicConfig(level=logging.INFO)

def test_flashinfer_jit_warmup(self):
"""测试flashinfer JIT预热和编译功能"""
try:
# Verify that torch and flashinfer are imported from .cache rather than Bazel runfiles
import flashinfer

torch_path = torch.__file__
flashinfer_path = flashinfer.__file__

logging.info(f"Torch imported from: {torch_path}")
logging.info(f"FlashInfer imported from: {flashinfer_path}")

# Assert: the torch path must not contain runfiles
self.assertNotIn(
"runfiles",
torch_path,
f"Torch should not be imported from Bazel runfiles, but got: {torch_path}",
)
# Assert: the torch path should contain .cache
self.assertIn(
".cache",
torch_path,
f"Torch should be imported from .cache, but got: {torch_path}",
)

# Assert: the flashinfer path must not contain runfiles
self.assertNotIn(
"runfiles",
flashinfer_path,
f"FlashInfer should not be imported from Bazel runfiles, but got: {flashinfer_path}",
)
# Assert: the flashinfer path should contain .cache
self.assertIn(
".cache",
flashinfer_path,
f"FlashInfer should be imported from .cache, but got: {flashinfer_path}",
)

logging.info("✓ Package import paths validated successfully")

# Check whether the torch.utils.cpp_extension paths are configured correctly
import torch.utils.cpp_extension as cpp_ext

logging.info(
f"warmup_flashinfer_python before update cpp_extension.include_paths: _HERE = {cpp_ext._HERE}"
)
logging.info(
f"warmup_flashinfer_python before update cpp_extension.include_paths: _TORCH_PATH = {cpp_ext._TORCH_PATH}"
)
logging.info(
f"warmup_flashinfer_python before update cpp_extension.include_paths: TORCH_LIB_PATH = {cpp_ext.TORCH_LIB_PATH}"
)

# Run the flashinfer warmup, which triggers JIT compilation
warmup_flashinfer_python()

logging.info(
f"warmup_flashinfer_python cpp_extension.include_paths: _HERE = {cpp_ext._HERE}"
)
logging.info(
f"warmup_flashinfer_python cpp_extension.include_paths: _TORCH_PATH = {cpp_ext._TORCH_PATH}"
)
logging.info(
f"warmup_flashinfer_python cpp_extension.include_paths: TORCH_LIB_PATH = {cpp_ext.TORCH_LIB_PATH}"
)

logging.info("FlashInfer JIT warmup completed successfully")

except ImportError as e:
self.skipTest(f"FlashInfer not available: {e}")
except Exception as e:
self.fail(f"FlashInfer JIT warmup failed: {e}")


if __name__ == "__main__":
unittest.main()
5 changes: 4 additions & 1 deletion rtp_llm/test/utils/BUILD
@@ -27,7 +27,10 @@ py_binary(
name = "gpu_lock",
deps = [":torch"],
main = "device_resource.py",
srcs = ["device_resource.py"],
srcs = [
"device_resource.py",
"jit_sys_path_setup.py",
],
)

py_library(
6 changes: 5 additions & 1 deletion rtp_llm/test/utils/device_resource.py
@@ -121,13 +121,17 @@ def __exit__(self, *args: Any):

if __name__ == "__main__":
cuda_info = get_cuda_info()

if not cuda_info:
logging.info("no gpu, continue")
result = subprocess.run(sys.argv[1:])
logging.info("exitcode: %d", result.returncode)

sys.exit(result.returncode)
else:
from jit_sys_path_setup import setup_jit_cache

setup_jit_cache()

device_name, _ = cuda_info
require_count = int(
os.environ.get("WORLD_SIZE", os.environ.get("GPU_COUNT", "1"))
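
Note: `jit_sys_path_setup.py` is added to the `gpu_lock` binary's srcs above, but its contents are not part of this diff. A minimal sketch of what a `setup_jit_cache()` helper could look like, assuming its job is to put the `.cache`-installed site-packages ahead of the Bazel runfiles copies on `sys.path` (which is what the new test asserts for torch and flashinfer); module name, environment variable, and default path below are assumptions, not the actual implementation:

# Hypothetical sketch of jit_sys_path_setup.py -- the real file is not shown
# in this diff, so the layout below is an assumption for illustration only.
import os
import sys


def setup_jit_cache() -> None:
    """Prefer the .cache-installed packages over the Bazel runfiles copies.

    flashinfer JIT compilation needs torch headers and libraries at a real
    install location; the new test asserts torch/flashinfer resolve to .cache.
    """
    # JIT_CACHE_SITE_PACKAGES is an assumed knob, not a documented variable.
    cache_site = os.environ.get(
        "JIT_CACHE_SITE_PACKAGES",
        os.path.expanduser("~/.cache/site-packages"),
    )
    if os.path.isdir(cache_site) and cache_site not in sys.path:
        # Prepend so imports resolve here before the Bazel runfiles copies.
        sys.path.insert(0, cache_site)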