Skip to content

Commit 57822f5

Browse files
committed
Revert "[main] remove pytorch-triton-rocm (#90)"
This reverts commit a6ac118.
1 parent 06e5126 commit 57822f5

File tree

1 file changed

+12
-21
lines changed

1 file changed

+12
-21
lines changed

manywheel/build_rocm.sh

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ ROCM_VERSION_WITH_PATCH=rocm${ROCM_VERSION_MAJOR}.${ROCM_VERSION_MINOR}.${ROCM_V
7575
ROCM_INT=$(($ROCM_VERSION_MAJOR * 10000 + $ROCM_VERSION_MINOR * 100 + $ROCM_VERSION_PATCH))
7676

7777
PYTORCH_VERSION=$(cat $PYTORCH_ROOT/version.txt | grep -oP "[0-9]+\.[0-9]+\.[0-9]+")
78-
PYTORCH_VERSION_FULL=$(cat "$PYTORCH_ROOT/version.txt")
78+
7979
do_lightweight_build() {
8080
echo "=== Building LIGHTWEIGHT variant ==="
8181

@@ -348,34 +348,25 @@ ver() {
348348
# Assuming PYTORCH_VERSION=x.y.z, if x >= 2
349349
if [ ${PYTORCH_VERSION%%\.*} -ge 2 ]; then
350350
if [[ $(uname) == "Linux" ]] && [[ "$DESIRED_PYTHON" != "3.12" || $(ver $PYTORCH_VERSION) -ge $(ver 2.4) ]]; then
351-
# Triton commit got unified in PyTorch 2.5
352-
if [[ $(ver $PYTORCH_VERSION) -ge $(ver 2.5) ]]; then
351+
# Triton commit got unified in PyTorch 2.5
352+
if [[ $(ver $PYTORCH_VERSION) -ge $(ver 2.5) ]]; then
353353
TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt)
354-
else
354+
else
355355
TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt)
356-
fi
356+
fi
357357
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
358-
# Only linux Python < 3.13 are supported wheels for triton
359-
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64'$(if [[ $(ver "$PYTORCH_VERSION") -le $(ver "2.5") ]]; then echo " and python_version < '3.13'"; fi)"
360-
# Use "triton" for dev builds, else "pytorch-triton-rocm"
361-
# Temp: Currently enabling for rocm7.1_internal_testing branch only but plan to expand it to other branches
362-
if [[ "$PYTORCH_VERSION_FULL" == *"2.9.0a0"* ]]; then
363-
PKG="triton"
364-
else
365-
PKG="pytorch-triton-rocm"
366-
fi
367-
368-
REQ="${PKG}==${TRITON_VERSION}+${ROCM_VERSION_WITH_PATCH}.git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
369-
370-
if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
371-
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${REQ}"
358+
# Only linux Python < 3.13 are supported wheels for triton
359+
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64'$(if [[ $(ver "$PYTORCH_VERSION") -le $(ver "2.5") ]]; then echo " and python_version < '3.13'"; fi)"
360+
361+
if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then
362+
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="pytorch-triton-rocm==${TRITON_VERSION}+${ROCM_VERSION_WITH_PATCH}.git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
372363
else
373-
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${REQ}"
364+
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | pytorch-triton-rocm==${TRITON_VERSION}+${ROCM_VERSION_WITH_PATCH}.git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
374365
fi
375-
unset PKG REQ
376366
fi
377367
fi
378368

369+
379370
echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}"
380371

381372
export LIGHTWEIGHT_WHEELNAME_MARKER="${LIGHTWEIGHT_WHEELNAME_MARKER}"

0 commit comments

Comments
 (0)