Skip to content

Commit c9268e9

Browse files
jithunnair-amdethanwee1
authored andcommitted
[release/2.7] Enable wheels
(cherry picked from commit 93864a8 with modifications for release/2.7) Reintroduce CIRCLE_TAG to be able to set PYTORCH_BUILD_VERSION without date This logic was present until release/2.2 (https://github.com/ROCm/pytorch/blob/4cd7f3ac9078ed449b8ae096887125f9b3b30659/.circleci/scripts/binary_populate_env.sh#L14) but was removed in release/2.3 (cherry picked from commit abe69fe) (cherry picked from commit dc9563a) (cherry picked from commit 88b9764)
1 parent 1341794 commit c9268e9

File tree

4 files changed

+42
-6
lines changed

4 files changed

+42
-6
lines changed

.circleci/scripts/binary_populate_env.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ export TZ=UTC
55
tagged_version() {
66
GIT_DIR="${workdir}/pytorch/.git"
77
GIT_DESCRIBE="git --git-dir ${GIT_DIR} describe --tags --match v[0-9]*.[0-9]*.[0-9]*"
8-
if [[ ! -d "${GIT_DIR}" ]]; then
8+
if [[ -n "${CIRCLE_TAG:-}" ]]; then
9+
echo "${CIRCLE_TAG}"
10+
elif [[ ! -d "${GIT_DIR}" ]]; then
911
echo "Abort, abort! Git dir ${GIT_DIR} does not exists!"
1012
kill $$
1113
elif ${GIT_DESCRIBE} --exact >/dev/null; then
@@ -69,6 +71,8 @@ fi
6971

7072
export PYTORCH_BUILD_NUMBER=1
7173

74+
# This part is done in the builder scripts so commenting the duplicate code
75+
: <<'BLOCK_COMMENT'
7276
# Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS
7377
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
7478
@@ -116,6 +120,7 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_B
116120
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}"
117121
fi
118122
fi
123+
BLOCK_COMMENT
119124

120125
USE_GLOO_WITH_OPENSSL="ON"
121126
if [[ "$GPU_ARCH_TYPE" =~ .*aarch64.* ]]; then

.github/scripts/build_triton_wheel.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22

33
import os
4+
import re
45
import shutil
56
import sys
67
from pathlib import Path
@@ -47,6 +48,30 @@ def patch_init_py(
4748
with open(path, "w") as f:
4849
f.write(orig)
4950

51+
def get_rocm_version() -> str:
52+
rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
53+
rocm_version = "0.0.0"
54+
rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
55+
if not os.path.isfile(rocm_version_h):
56+
rocm_version_h = f"{rocm_path}/include/rocm_version.h"
57+
# The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
58+
if os.path.isfile(rocm_version_h):
59+
RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
60+
RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
61+
RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
62+
major, minor, patch = 0, 0, 0
63+
for line in open(rocm_version_h):
64+
match = RE_MAJOR.search(line)
65+
if match:
66+
major = int(match.group(1))
67+
match = RE_MINOR.search(line)
68+
if match:
69+
minor = int(match.group(1))
70+
match = RE_PATCH.search(line)
71+
if match:
72+
patch = int(match.group(1))
73+
rocm_version = str(major)+"."+str(minor)+"."+str(patch)
74+
return rocm_version
5075

5176
def build_triton(
5277
*,
@@ -61,7 +86,12 @@ def build_triton(
6186
if "MAX_JOBS" not in env:
6287
max_jobs = os.cpu_count() or 1
6388
env["MAX_JOBS"] = str(max_jobs)
64-
89+
if not release:
90+
# Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
91+
# while release build should only include the version, i.e. 2.1.0
92+
rocm_version = get_rocm_version()
93+
version_suffix = f"+rocm{rocm_version}.git{commit_hash[:8]}"
94+
version += version_suffix
6595
with TemporaryDirectory() as tmpdir:
6696
triton_basedir = Path(tmpdir) / "triton"
6797
triton_pythondir = triton_basedir / "python"
@@ -84,6 +114,7 @@ def build_triton(
84114

85115
# change built wheel name and version
86116
env["TRITON_WHEEL_NAME"] = triton_pkg_name
117+
env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
87118
if with_clang_ldd:
88119
env["TRITON_BUILD_WITH_CLANG_LLD"] = "1"
89120

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ set(CMAKE_C_STANDARD
5454
# ---[ Utils
5555
include(cmake/public/utils.cmake)
5656

57-
# --- [ Check that minimal gcc version is 9.3+
58-
if(CMAKE_COMPILER_IS_GNUCXX AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.3)
57+
# --- [ Check that minimal gcc version is 9.2+
58+
if(CMAKE_COMPILER_IS_GNUCXX AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.2)
5959
message(
6060
FATAL_ERROR
61-
"GCC-9.3 or newer is required to compile PyTorch, but found ${CMAKE_CXX_COMPILER_VERSION}"
61+
"GCC-9.2 or newer is required to compile PyTorch, but found ${CMAKE_CXX_COMPILER_VERSION}"
6262
)
6363
endif()
6464

version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.7.0a0
1+
2.7.0

0 commit comments

Comments
 (0)