@@ -16,6 +16,9 @@ inputs:
1616 ref :
1717 description : Branch, tag, commit id
1818 default : " "
19+ mode :
20+ description : Source or wheels
21+ default : source
1922runs :
2023 using : " composite"
2124 steps :
7174 - name : Generate PyTorch cache key
7275 shell : bash
7376 run : |
74- PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} | sha256sum - | cut -d\ -f1)
77+ PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} ${{ inputs.mode }} | sha256sum - | cut -d\ -f1)
7578 echo "PYTORCH_CACHE_KEY=$PYTORCH_CACHE_KEY" | tee -a "$GITHUB_ENV"
7679
7780 - name : Load PyTorch from a cache
@@ -90,11 +93,12 @@ runs:
9093 with :
9194 repository : ${{ env.PYTORCH_REPO }}
9295 ref : ${{ env.PYTORCH_COMMIT_ID }}
93- submodules : recursive
96+ # To build PyTorch from source we need all submodules, they are not required for benchmarks
97+ submodules : ${{ inputs.mode == 'source' && 'recursive' || 'false' }}
9498 path : pytorch
9599
96100 - name : Apply additional PR patches
97- if : ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' }}
101+ if : ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' && inputs.mode == 'source' }}
98102 shell : bash
99103 run : |
100104 cd pytorch
@@ -108,7 +112,7 @@ runs:
108112 pip install 'numpy<2.0.0'
109113
110114 - name : Build PyTorch
111- if : ${{ steps.pytorch-cache.outputs.status == 'miss' }}
115+ if : ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.mode == 'source' }}
112116 shell : bash
113117 run : |
114118 source ${{ inputs.oneapi }}/setvars.sh
@@ -117,11 +121,24 @@ runs:
117121 pip install -r requirements.txt
118122 python setup.py bdist_wheel
119123
120- - name : Install PyTorch
124+ - name : Install PyTorch (built from source)
125+ if : ${{ inputs.mode == 'source' }}
121126 shell : bash
122127 run : |
123128 source ${{ inputs.oneapi }}/setvars.sh
124129 pip install pytorch/dist/*.whl
130+
131+ - name : Install PyTorch (from wheels)
132+ if : ${{ inputs.mode == 'wheels' }}
133+ shell : bash
134+ run : |
135+ source ${{ inputs.oneapi }}/setvars.sh
136+ pip install torch --index-url https://download.pytorch.org/whl/nightly/xpu
137+
138+ - name : Get PyTorch version
139+ shell : bash
140+ run : |
141+ source ${{ inputs.oneapi }}/setvars.sh
125142 PYTORCH_VERSION="$(python -c 'import torch;print(torch.__version__)')"
126143 echo "PYTORCH_VERSION=$PYTORCH_VERSION" | tee -a "$GITHUB_ENV"
127144
0 commit comments