Skip to content

Commit 0553283

Browse files
jithunnair-amdjeffdaily
authored andcommitted
[release/2.8] Enable wheels
(cherry picked from commit e294d4d with modifications for release/2.8) Reintroduce CIRCLE_TAG to be able to set PYTORCH_BUILD_VERSION without date (cherry picked from commit 71a30ea)
1 parent d38164a commit 0553283

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

.circleci/scripts/binary_populate_env.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ export TZ=UTC
55
tagged_version() {
66
GIT_DIR="${workdir}/pytorch/.git"
77
GIT_DESCRIBE="git --git-dir ${GIT_DIR} describe --tags --match v[0-9]*.[0-9]*.[0-9]*"
8-
if [[ ! -d "${GIT_DIR}" ]]; then
8+
if [[ -n "${CIRCLE_TAG:-}" ]]; then
9+
echo "${CIRCLE_TAG}"
10+
elif [[ ! -d "${GIT_DIR}" ]]; then
911
echo "Abort, abort! Git dir ${GIT_DIR} does not exists!"
1012
kill $$
1113
elif ${GIT_DESCRIBE} --exact >/dev/null; then
@@ -69,6 +71,8 @@ fi
6971

7072
export PYTORCH_BUILD_NUMBER=1
7173

74+
# This part is done in the builder scripts so commenting the duplicate code
75+
: <<'BLOCK_COMMENT'
7276
# Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS
7377
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
7478
TRITON_CONSTRAINT="platform_system == 'Linux'"
@@ -110,6 +114,7 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_B
110114
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}"
111115
fi
112116
fi
117+
BLOCK_COMMENT
113118

114119
USE_GLOO_WITH_OPENSSL="ON"
115120
if [[ "$GPU_ARCH_TYPE" =~ .*aarch64.* ]]; then

.github/scripts/build_triton_wheel.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22

33
import os
4+
import re
45
import shutil
56
import sys
67
from pathlib import Path
@@ -50,6 +51,30 @@ def patch_init_py(
5051
with open(path, "w") as f:
5152
f.write(orig)
5253

54+
def get_rocm_version() -> str:
55+
rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
56+
rocm_version = "0.0.0"
57+
rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
58+
if not os.path.isfile(rocm_version_h):
59+
rocm_version_h = f"{rocm_path}/include/rocm_version.h"
60+
# The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
61+
if os.path.isfile(rocm_version_h):
62+
RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
63+
RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
64+
RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
65+
major, minor, patch = 0, 0, 0
66+
for line in open(rocm_version_h):
67+
match = RE_MAJOR.search(line)
68+
if match:
69+
major = int(match.group(1))
70+
match = RE_MINOR.search(line)
71+
if match:
72+
minor = int(match.group(1))
73+
match = RE_PATCH.search(line)
74+
if match:
75+
patch = int(match.group(1))
76+
rocm_version = str(major)+"."+str(minor)+"."+str(patch)
77+
return rocm_version
5378

5479
def build_triton(
5580
*,
@@ -64,7 +89,12 @@ def build_triton(
6489
if "MAX_JOBS" not in env:
6590
max_jobs = os.cpu_count() or 1
6691
env["MAX_JOBS"] = str(max_jobs)
67-
92+
if not release:
93+
# Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
94+
# while release build should only include the version, i.e. 2.1.0
95+
rocm_version = get_rocm_version()
96+
version_suffix = f"+rocm{rocm_version}.git{commit_hash[:8]}"
97+
version += version_suffix
6898
with TemporaryDirectory() as tmpdir:
6999
triton_basedir = Path(tmpdir) / "triton"
70100
triton_pythondir = triton_basedir / "python"
@@ -89,6 +119,7 @@ def build_triton(
89119

90120
# change built wheel name and version
91121
env["TRITON_WHEEL_NAME"] = triton_pkg_name
122+
env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
92123
if with_clang_ldd:
93124
env["TRITON_BUILD_WITH_CLANG_LLD"] = "1"
94125

0 commit comments

Comments
 (0)