Skip to content

Commit 3a79d03

Browse files
authored
[https://nvbugs/5617275][fix] Extract py files from prebuilt wheel for editable installs (#8738)
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent aecc965 commit 3a79d03

File tree

2 files changed

+84
-5
lines changed

2 files changed

+84
-5
lines changed

setup.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,11 @@ def has_ext_modules(self):
118118
'libs/libdecoder_attention_1.so', 'libs/nvshmem/License.txt',
119119
'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
120120
'libs/nvshmem/nvshmem_transport_ibgda.so.103', 'bindings.*.so',
121-
'deep_ep/LICENSE', 'deep_ep_cpp_tllm.*.so', "include/**/*",
122-
'deep_gemm/LICENSE', 'deep_gemm/include/**/*',
123-
'deep_gemm_cpp_tllm.*.so', 'scripts/install_tensorrt.sh',
124-
'flash_mla/LICENSE', 'flash_mla_cpp_tllm.*.so'
121+
'deep_ep/LICENSE', 'deep_ep/*.py', 'deep_ep_cpp_tllm.*.so',
122+
"include/**/*", 'deep_gemm/LICENSE', 'deep_gemm/include/**/*',
123+
'deep_gemm/*.py', 'deep_gemm_cpp_tllm.*.so',
124+
'scripts/install_tensorrt.sh', 'flash_mla/LICENSE', 'flash_mla/*.py',
125+
'flash_mla_cpp_tllm.*.so'
125126
]
126127

127128
package_data += [
@@ -203,8 +204,19 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
203204

204205
with zipfile.ZipFile(wheel_path) as wheel:
205206
for file in wheel.filelist:
206-
if file.filename.endswith((".py", ".yaml")):
207+
# Skip yaml files
208+
if file.filename.endswith(".yaml"):
207209
continue
210+
211+
# Skip .py files EXCEPT for generated C++ extension wrappers
212+
# (deep_gemm, deep_ep, flash_mla Python files are generated during build)
213+
if file.filename.endswith(".py"):
214+
allowed_dirs = ("tensorrt_llm/deep_gemm/",
215+
"tensorrt_llm/deep_ep/",
216+
"tensorrt_llm/flash_mla/")
217+
if not any(file.filename.startswith(d) for d in allowed_dirs):
218+
continue
219+
208220
for filename_pattern in package_data:
209221
if fnmatch.fnmatchcase(file.filename,
210222
f"tensorrt_llm/{filename_pattern}"):
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""
2+
Test that prebuilt wheel extraction includes all necessary Python files.
3+
4+
"""
5+
from pathlib import Path
6+
7+
8+
def test_cpp_extension_wrapper_files_exist():
9+
"""Verify that C++ extension wrapper Python files were extracted from prebuilt wheel."""
10+
import tensorrt_llm
11+
12+
trtllm_root = Path(tensorrt_llm.__file__).parent
13+
14+
# C++ extensions that have Python wrapper files generated during build
15+
required_files = {
16+
'deep_gemm':
17+
['__init__.py', 'testing/__init__.py', 'utils/__init__.py'],
18+
'deep_ep': ['__init__.py', 'buffer.py', 'utils.py'],
19+
'flash_mla': ['__init__.py', 'flash_mla_interface.py'],
20+
}
21+
22+
missing_files = []
23+
for ext_dir, files in required_files.items():
24+
for file in files:
25+
file_path = trtllm_root / ext_dir / file
26+
if not file_path.exists():
27+
missing_files.append(str(file_path.relative_to(trtllm_root)))
28+
29+
assert not missing_files, (
30+
f"Missing Python wrapper files for C++ extensions: {missing_files}\n"
31+
f"This indicates setup.py may not be extracting Python files from prebuilt wheels.\n"
32+
f"Check setup.py extract_from_precompiled() function.")
33+
34+
35+
def test_cpp_extensions_importable():
36+
"""Verify that C++ extension wrappers can be imported successfully."""
37+
import_tests = [
38+
('tensorrt_llm.deep_gemm', 'fp8_mqa_logits'),
39+
('tensorrt_llm.deep_ep', 'Buffer'),
40+
('tensorrt_llm.flash_mla', 'flash_mla_interface'),
41+
]
42+
43+
failed_imports = []
44+
for module_name, attr_name in import_tests:
45+
try:
46+
module = __import__(module_name, fromlist=[attr_name])
47+
if not hasattr(module, attr_name):
48+
failed_imports.append(
49+
f"{module_name}.{attr_name} (attribute not found)")
50+
except ImportError as e:
51+
failed_imports.append(f"{module_name} (ImportError: {e})")
52+
53+
assert not failed_imports, (
54+
f"Failed to import C++ extension wrappers: {failed_imports}\n"
55+
f"This may indicate missing Python files or circular import issues.")
56+
57+
58+
if __name__ == '__main__':
59+
print("Testing C++ extension wrapper files...")
60+
test_cpp_extension_wrapper_files_exist()
61+
print("✅ All required Python files exist")
62+
63+
print("\nTesting C++ extension imports...")
64+
test_cpp_extensions_importable()
65+
print("✅ All imports successful")
66+
67+
print("\n✅ All tests passed!")

0 commit comments

Comments
 (0)