@@ -75,7 +75,7 @@ ROCM_VERSION_WITH_PATCH=rocm${ROCM_VERSION_MAJOR}.${ROCM_VERSION_MINOR}.${ROCM_V
7575ROCM_INT=$(( $ROCM_VERSION_MAJOR * 10000 + $ROCM_VERSION_MINOR * 100 + $ROCM_VERSION_PATCH ))
7676
7777PYTORCH_VERSION=$( cat $PYTORCH_ROOT /version.txt | grep -oP " [0-9]+\.[0-9]+\.[0-9]+" )
78- PYTORCH_VERSION_FULL= $( cat " $PYTORCH_ROOT /version.txt " )
78+
7979do_lightweight_build () {
8080 echo " === Building LIGHTWEIGHT variant ==="
8181
@@ -348,34 +348,25 @@ ver() {
348348# Assuming PYTORCH_VERSION=x.y.z, if x >= 2
349349if [ ${PYTORCH_VERSION%% \. * } -ge 2 ]; then
350350 if [[ $( uname) == " Linux" ]] && [[ " $DESIRED_PYTHON " != " 3.12" || $( ver $PYTORCH_VERSION ) -ge $( ver 2.4) ]]; then
351- # Triton commit got unified in PyTorch 2.5
352- if [[ $( ver $PYTORCH_VERSION ) -ge $( ver 2.5) ]]; then
351+ # Triton commit got unified in PyTorch 2.5
352+ if [[ $( ver $PYTORCH_VERSION ) -ge $( ver 2.5) ]]; then
353353 TRITON_SHORTHASH=$( cut -c1-8 $PYTORCH_ROOT /.ci/docker/ci_commit_pins/triton.txt)
354- else
354+ else
355355 TRITON_SHORTHASH=$( cut -c1-8 $PYTORCH_ROOT /.ci/docker/ci_commit_pins/triton-rocm.txt)
356- fi
356+ fi
357357 TRITON_VERSION=$( cat $PYTORCH_ROOT /.ci/docker/triton_version.txt)
358- # Only linux Python < 3.13 are supported wheels for triton
359- TRITON_CONSTRAINT=" platform_system == 'Linux' and platform_machine == 'x86_64'$( if [[ $( ver " $PYTORCH_VERSION " ) -le $( ver " 2.5" ) ]]; then echo " and python_version < '3.13'" ; fi) "
360- # Use "triton" for dev builds, else "pytorch-triton-rocm"
361- # Temp: Currently enabling for rocm7.1_internal_testing branch only but plan to expand it to other branches
362- if [[ " $PYTORCH_VERSION_FULL " == * " 2.9.0a0" * ]]; then
363- PKG=" triton"
364- else
365- PKG=" pytorch-triton-rocm"
366- fi
367-
368- REQ=" ${PKG} ==${TRITON_VERSION} +${ROCM_VERSION_WITH_PATCH} .git${TRITON_SHORTHASH} ; ${TRITON_CONSTRAINT} "
369-
370- if [[ -z " $PYTORCH_EXTRA_INSTALL_REQUIREMENTS " ]]; then
371- export PYTORCH_EXTRA_INSTALL_REQUIREMENTS=" ${REQ} "
358+ # Only linux Python < 3.13 are supported wheels for triton
359+ TRITON_CONSTRAINT=" platform_system == 'Linux' and platform_machine == 'x86_64'$( if [[ $( ver " $PYTORCH_VERSION " ) -le $( ver " 2.5" ) ]]; then echo " and python_version < '3.13'" ; fi) "
360+
361+ if [[ -z " $PYTORCH_EXTRA_INSTALL_REQUIREMENTS " ]]; then
362+ export PYTORCH_EXTRA_INSTALL_REQUIREMENTS=" pytorch-triton-rocm==${TRITON_VERSION} +${ROCM_VERSION_WITH_PATCH} .git${TRITON_SHORTHASH} ; ${TRITON_CONSTRAINT} "
372363 else
373- export PYTORCH_EXTRA_INSTALL_REQUIREMENTS=" ${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${REQ }"
364+ export PYTORCH_EXTRA_INSTALL_REQUIREMENTS=" ${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | pytorch-triton-rocm== ${TRITON_VERSION} + ${ROCM_VERSION_WITH_PATCH} .git ${TRITON_SHORTHASH} ; ${TRITON_CONSTRAINT }"
374365 fi
375- unset PKG REQ
376366 fi
377367fi
378368
369+
379370echo " PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH} "
380371
381372export LIGHTWEIGHT_WHEELNAME_MARKER=" ${LIGHTWEIGHT_WHEELNAME_MARKER} "
0 commit comments