@@ -75,7 +75,7 @@ ROCM_VERSION_WITH_PATCH=rocm${ROCM_VERSION_MAJOR}.${ROCM_VERSION_MINOR}.${ROCM_V
7575ROCM_INT=$(( $ROCM_VERSION_MAJOR * 10000 + $ROCM_VERSION_MINOR * 100 + $ROCM_VERSION_PATCH ))
7676
7777PYTORCH_VERSION=$( cat $PYTORCH_ROOT /version.txt | grep -oP " [0-9]+\.[0-9]+\.[0-9]+" )
78-
78+ PYTORCH_VERSION_FULL= $( cat " $PYTORCH_ROOT /version.txt " )
7979do_lightweight_build () {
8080 echo " === Building LIGHTWEIGHT variant ==="
8181
@@ -348,25 +348,34 @@ ver() {
348348# Assuming PYTORCH_VERSION=x.y.z, if x >= 2
349349if [ ${PYTORCH_VERSION%% \. * } -ge 2 ]; then
350350 if [[ $( uname) == " Linux" ]] && [[ " $DESIRED_PYTHON " != " 3.12" || $( ver $PYTORCH_VERSION ) -ge $( ver 2.4) ]]; then
351- # Triton commit got unified in PyTorch 2.5
352- if [[ $( ver $PYTORCH_VERSION ) -ge $( ver 2.5) ]]; then
351+ # Triton commit got unified in PyTorch 2.5
352+ if [[ $( ver $PYTORCH_VERSION ) -ge $( ver 2.5) ]]; then
353353 TRITON_SHORTHASH=$( cut -c1-8 $PYTORCH_ROOT /.ci/docker/ci_commit_pins/triton.txt)
354- else
354+ else
355355 TRITON_SHORTHASH=$( cut -c1-8 $PYTORCH_ROOT /.ci/docker/ci_commit_pins/triton-rocm.txt)
356- fi
356+ fi
357357 TRITON_VERSION=$( cat $PYTORCH_ROOT /.ci/docker/triton_version.txt)
358- # Only linux Python < 3.13 are supported wheels for triton
359- TRITON_CONSTRAINT=" platform_system == 'Linux' and platform_machine == 'x86_64'$( if [[ $( ver " $PYTORCH_VERSION " ) -le $( ver " 2.5" ) ]]; then echo " and python_version < '3.13'" ; fi) "
360-
361- if [[ -z " $PYTORCH_EXTRA_INSTALL_REQUIREMENTS " ]]; then
362- export PYTORCH_EXTRA_INSTALL_REQUIREMENTS=" pytorch-triton-rocm==${TRITON_VERSION} +${ROCM_VERSION_WITH_PATCH} .git${TRITON_SHORTHASH} ; ${TRITON_CONSTRAINT} "
358+ # Only linux Python < 3.13 are supported wheels for triton
359+ TRITON_CONSTRAINT=" platform_system == 'Linux' and platform_machine == 'x86_64'$( if [[ $( ver " $PYTORCH_VERSION " ) -le $( ver " 2.5" ) ]]; then echo " and python_version < '3.13'" ; fi) "
360+ # Use "triton" for dev builds, else "pytorch-triton-rocm"
361+ # Temp: Currently enabling for rocm7.1_internal_testing branch only but plan to expand it to other branches
362+ if [[ " $PYTORCH_VERSION_FULL " == * " 2.9.0a0" * ]]; then
363+ PKG=" triton"
363364 else
364- export PYTORCH_EXTRA_INSTALL_REQUIREMENTS= " ${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | pytorch-triton-rocm== ${TRITON_VERSION} + ${ROCM_VERSION_WITH_PATCH} .git ${TRITON_SHORTHASH} ; ${TRITON_CONSTRAINT} "
365+ PKG= " pytorch-triton-rocm"
365366 fi
367+
368+ REQ=" ${PKG} ==${TRITON_VERSION} +${ROCM_VERSION_WITH_PATCH} .git${TRITON_SHORTHASH} ; ${TRITON_CONSTRAINT} "
369+
370+ if [[ -z " $PYTORCH_EXTRA_INSTALL_REQUIREMENTS " ]]; then
371+ export PYTORCH_EXTRA_INSTALL_REQUIREMENTS=" ${REQ} "
372+ else
373+ export PYTORCH_EXTRA_INSTALL_REQUIREMENTS=" ${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${REQ} "
374+ fi
375+ unset PKG REQ
366376 fi
367377fi
368378
369-
370379echo " PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH} "
371380
372381export LIGHTWEIGHT_WHEELNAME_MARKER=" ${LIGHTWEIGHT_WHEELNAME_MARKER} "
0 commit comments