11#!/usr/bin/env python3
22
33import os
4+ import re
45import shutil
56import sys
67from pathlib import Path
@@ -50,6 +51,30 @@ def patch_init_py(
5051 with open (path , "w" ) as f :
5152 f .write (orig )
5253
54+ def get_rocm_version () -> str :
55+ rocm_path = os .environ .get ('ROCM_HOME' ) or os .environ .get ('ROCM_PATH' ) or "/opt/rocm"
56+ rocm_version = "0.0.0"
57+ rocm_version_h = f"{ rocm_path } /include/rocm-core/rocm_version.h"
58+ if not os .path .isfile (rocm_version_h ):
59+ rocm_version_h = f"{ rocm_path } /include/rocm_version.h"
60+ # The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
61+ if os .path .isfile (rocm_version_h ):
62+ RE_MAJOR = re .compile (r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)" )
63+ RE_MINOR = re .compile (r"#define\s+ROCM_VERSION_MINOR\s+(\d+)" )
64+ RE_PATCH = re .compile (r"#define\s+ROCM_VERSION_PATCH\s+(\d+)" )
65+ major , minor , patch = 0 , 0 , 0
66+ for line in open (rocm_version_h ):
67+ match = RE_MAJOR .search (line )
68+ if match :
69+ major = int (match .group (1 ))
70+ match = RE_MINOR .search (line )
71+ if match :
72+ minor = int (match .group (1 ))
73+ match = RE_PATCH .search (line )
74+ if match :
75+ patch = int (match .group (1 ))
76+ rocm_version = str (major )+ "." + str (minor )+ "." + str (patch )
77+ return rocm_version
5378
5479def build_triton (
5580 * ,
@@ -64,7 +89,12 @@ def build_triton(
6489 if "MAX_JOBS" not in env :
6590 max_jobs = os .cpu_count () or 1
6691 env ["MAX_JOBS" ] = str (max_jobs )
67-
92+ if not release :
93+ # Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
94+ # while release build should only include the version, i.e. 2.1.0
95+ rocm_version = get_rocm_version ()
96+ version_suffix = f"+rocm{ rocm_version } .git{ commit_hash [:8 ]} "
97+ version += version_suffix
6898 with TemporaryDirectory () as tmpdir :
6999 triton_basedir = Path (tmpdir ) / "triton"
70100 triton_pythondir = triton_basedir / "python"
@@ -89,6 +119,7 @@ def build_triton(
89119
90120 # change built wheel name and version
91121 env ["TRITON_WHEEL_NAME" ] = triton_pkg_name
122+ env ["TRITON_WHEEL_VERSION_SUFFIX" ] = version_suffix
92123 if with_clang_ldd :
93124 env ["TRITON_BUILD_WITH_CLANG_LLD" ] = "1"
94125
0 commit comments