Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cmake/public/LoadHIP.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ if(HIP_FOUND)
set(PYTORCH_FOUND_HIP TRUE)
find_package_and_print_version(hip REQUIRED CONFIG)

if(HIP_VERSION)
# Check if HIP_VERSION contains a dash (e.g., "7.1.25421-32f9fa6ca5")
# and strip everything after it to get clean numeric version
string(FIND "${HIP_VERSION}" "-" DASH_POS)
if(NOT DASH_POS EQUAL -1)
string(SUBSTRING "${HIP_VERSION}" 0 ${DASH_POS} HIP_VERSION_CLEAN)
set(HIP_VERSION "${HIP_VERSION_CLEAN}")
endif()
message("HIP version: ${HIP_VERSION}")
endif()

# The rocm-core package was only introduced in ROCm 6.4, so we make it optional.
find_package(rocm-core CONFIG)

Expand Down
5 changes: 4 additions & 1 deletion tools/generate_torch_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,15 @@ def get_torch_version(sha: str | None = None) -> str:
)
parser.add_argument("--cuda-version", "--cuda_version", type=str)
parser.add_argument("--hip-version", "--hip_version", type=str)
parser.add_argument("--rocm-version", "--rocm_version", type=str)
parser.add_argument("--xpu-version", "--xpu_version", type=str)

args = parser.parse_args()

assert args.is_debug is not None
args.cuda_version = None if args.cuda_version == "" else args.cuda_version
args.hip_version = None if args.hip_version == "" else args.hip_version
args.rocm_version = None if args.rocm_version == "" else args.rocm_version
args.xpu_version = None if args.xpu_version == "" else args.xpu_version

pytorch_root = Path(__file__).parent.parent
Expand All @@ -141,7 +143,7 @@ def get_torch_version(sha: str | None = None) -> str:
with open(version_path, "w") as f:
f.write("from typing import Optional\n\n")
f.write(
"__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'xpu']\n"
"__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'rocm', 'xpu']\n"
)
f.write(f"__version__ = '{version}'\n")
# NB: This is not 100% accurate, because you could have built the
Expand All @@ -151,4 +153,5 @@ def get_torch_version(sha: str | None = None) -> str:
f.write(f"cuda: Optional[str] = {repr(args.cuda_version)}\n")
f.write(f"git_version = {repr(sha)}\n")
f.write(f"hip: Optional[str] = {repr(args.hip_version)}\n")
f.write(f"rocm: Optional[str] = {repr(args.rocm_version)}\n")
f.write(f"xpu: Optional[str] = {repr(args.xpu_version)}\n")
3 changes: 2 additions & 1 deletion torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,8 @@ add_custom_target(
"${Python_EXECUTABLE}" "${TOOLS_PATH}/generate_torch_version.py"
--is-debug=${TORCH_VERSION_DEBUG}
--cuda-version=${CUDA_VERSION}
--hip-version=${ROCM_VERSION_DEV}
--hip-version=${HIP_VERSION}
--rocm-version=${ROCM_VERSION_DEV}
--xpu-version=${SYCL_COMPILER_VERSION}
BYPRODUCTS ${TORCH_SRC_DIR}/version.py
COMMENT "Regenerating version file..."
Expand Down
1 change: 1 addition & 0 deletions torch/version.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ cuda = '{{CUDA_VERSION}}'
# TODO: use workspace status to stamp the correct version
git_version = ""
hip = None
rocm = None

# This is a gross monkey-patch hack that depends on the order of imports
# in torch/__init__.py
Expand Down