diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..c34025772 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include transformer_engine/common/include *.* diff --git a/setup.py b/setup.py index b28644e03..f002e2edf 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -13,12 +13,15 @@ import subprocess import setuptools +from setuptools.command.egg_info import egg_info from wheel.bdist_wheel import bdist_wheel from build_tools.build_ext import CMakeExtension, get_build_ext from build_tools.te_version import te_version from build_tools.utils import ( rocm_build, + all_files_in_dir, + hipify, cuda_archs, get_frameworks, remove_dups, @@ -37,6 +40,15 @@ elif "jax" in frameworks: from pybind11.setup_helpers import build_ext as BuildExtension +class HipifyMeta(egg_info): + """Custom egg_info command to hipify source files before packaging.""" + + def run(self): + if rocm_build(): + print("Running hipification of installable headers for ROCm build...") + common_headers_dir = current_file_path / "transformer_engine/common/include" + hipify(current_file_path, common_headers_dir, all_files_in_dir(common_headers_dir), []) + super().run() CMakeBuildExtension = get_build_ext(BuildExtension) if not rocm_build(): @@ -212,7 +224,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]: long_description=long_description, long_description_content_type="text/x-rst", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, + cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=">=3.8", classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires,