@@ -119,13 +119,15 @@ def get_torch_version(sha: str | None = None) -> str:
119119 )
120120 parser .add_argument ("--cuda-version" , "--cuda_version" , type = str )
121121 parser .add_argument ("--hip-version" , "--hip_version" , type = str )
122+ parser .add_argument ("--rocm-version" , "--rocm_version" , type = str )
122123 parser .add_argument ("--xpu-version" , "--xpu_version" , type = str )
123124
124125 args = parser .parse_args ()
125126
126127 assert args .is_debug is not None
127128 args .cuda_version = None if args .cuda_version == "" else args .cuda_version
128129 args .hip_version = None if args .hip_version == "" else args .hip_version
130+ args .rocm_version = None if args .rocm_version == "" else args .rocm_version
129131 args .xpu_version = None if args .xpu_version == "" else args .xpu_version
130132
131133 pytorch_root = Path (__file__ ).parent .parent
@@ -141,7 +143,7 @@ def get_torch_version(sha: str | None = None) -> str:
141143 with open (version_path , "w" ) as f :
142144 f .write ("from typing import Optional\n \n " )
143145 f .write (
144- "__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'xpu']\n "
146+ "__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'rocm', ' xpu']\n "
145147 )
146148 f .write (f"__version__ = '{ version } '\n " )
147149 # NB: This is not 100% accurate, because you could have built the
@@ -151,4 +153,5 @@ def get_torch_version(sha: str | None = None) -> str:
151153 f .write (f"cuda: Optional[str] = { repr (args .cuda_version )} \n " )
152154 f .write (f"git_version = { repr (sha )} \n " )
153155 f .write (f"hip: Optional[str] = { repr (args .hip_version )} \n " )
156+ f .write (f"rocm: Optional[str] = { repr (args .rocm_version )} \n " )
154157 f .write (f"xpu: Optional[str] = { repr (args .xpu_version )} \n " )
0 commit comments