Skip to content

Commit d2170fd

Browse files
committed
create torch.version.rocm variable to store rocm version
1 parent d0dbe05 commit d2170fd

File tree

3 files changed

+6
-1
lines changed

3 files changed

+6
-1
lines changed

tools/generate_torch_version.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,15 @@ def get_torch_version(sha: str | None = None) -> str:
119119
)
120120
parser.add_argument("--cuda-version", "--cuda_version", type=str)
121121
parser.add_argument("--hip-version", "--hip_version", type=str)
122+
parser.add_argument("--rocm-version", "--rocm_version", type=str)
122123
parser.add_argument("--xpu-version", "--xpu_version", type=str)
123124

124125
args = parser.parse_args()
125126

126127
assert args.is_debug is not None
127128
args.cuda_version = None if args.cuda_version == "" else args.cuda_version
128129
args.hip_version = None if args.hip_version == "" else args.hip_version
130+
args.rocm_version = None if args.rocm_version == "" else args.rocm_version
129131
args.xpu_version = None if args.xpu_version == "" else args.xpu_version
130132

131133
pytorch_root = Path(__file__).parent.parent
@@ -141,7 +143,7 @@ def get_torch_version(sha: str | None = None) -> str:
141143
with open(version_path, "w") as f:
142144
f.write("from typing import Optional\n\n")
143145
f.write(
144-
"__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'xpu']\n"
146+
"__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'rocm', 'xpu']\n"
145147
)
146148
f.write(f"__version__ = '{version}'\n")
147149
# NB: This is not 100% accurate, because you could have built the
@@ -151,4 +153,5 @@ def get_torch_version(sha: str | None = None) -> str:
151153
f.write(f"cuda: Optional[str] = {repr(args.cuda_version)}\n")
152154
f.write(f"git_version = {repr(sha)}\n")
153155
f.write(f"hip: Optional[str] = {repr(args.hip_version)}\n")
156+
f.write(f"rocm: Optional[str] = {repr(args.rocm_version)}\n")
154157
f.write(f"xpu: Optional[str] = {repr(args.xpu_version)}\n")

torch/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ add_custom_target(
491491
--is-debug=${TORCH_VERSION_DEBUG}
492492
--cuda-version=${CUDA_VERSION}
493493
--hip-version=${HIP_VERSION}
494+
--rocm-version=${ROCM_VERSION_DEV}
494495
--xpu-version=${SYCL_COMPILER_VERSION}
495496
BYPRODUCTS ${TORCH_SRC_DIR}/version.py
496497
COMMENT "Regenerating version file..."

torch/version.py.tpl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ cuda = '{{CUDA_VERSION}}'
44
# TODO: use workspace status to stamp the correct version
55
git_version = ""
66
hip = None
7+
rocm = None
78

89
# This is a gross monkey-patch hack that depends on the order of imports
910
# in torch/__init__.py

0 commit comments

Comments
 (0)