Skip to content

Commit 11258f5

Browse files
authored
Option to install upstream PyTorch from nightly wheels (#2386)
Adding a new workflow input to select how to install upstream PyTorch: build from sources (with the corresponding patches applied) or from the latest nightly wheels from https://download.pytorch.org/whl/nightly/xpu. The default is to build from source. Fixes #1913.
1 parent cac829d commit 11258f5

File tree

4 files changed

+43
-5
lines changed

4 files changed

+43
-5
lines changed

.github/actions/setup-pytorch/action.yml

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ inputs:
1616
ref:
1717
description: Branch, tag, commit id
1818
default: ""
19+
mode:
20+
description: Source or wheels
21+
default: source
1922
runs:
2023
using: "composite"
2124
steps:
@@ -71,7 +74,7 @@ runs:
7174
- name: Generate PyTorch cache key
7275
shell: bash
7376
run: |
74-
PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} | sha256sum - | cut -d\ -f1)
77+
PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} ${{ inputs.mode }} | sha256sum - | cut -d\ -f1)
7578
echo "PYTORCH_CACHE_KEY=$PYTORCH_CACHE_KEY" | tee -a "$GITHUB_ENV"
7679
7780
- name: Load PyTorch from a cache
@@ -90,11 +93,12 @@ runs:
9093
with:
9194
repository: ${{ env.PYTORCH_REPO }}
9295
ref: ${{ env.PYTORCH_COMMIT_ID }}
93-
submodules: recursive
96+
# To build PyTorch from source we need all submodules, they are not required for benchmarks
97+
submodules: ${{ inputs.mode == 'source' && 'recursive' || 'false' }}
9498
path: pytorch
9599

96100
- name: Apply additional PR patches
97-
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' }}
101+
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' && inputs.mode == 'source' }}
98102
shell: bash
99103
run: |
100104
cd pytorch
@@ -108,7 +112,7 @@ runs:
108112
pip install 'numpy<2.0.0'
109113
110114
- name: Build PyTorch
111-
if: ${{ steps.pytorch-cache.outputs.status == 'miss' }}
115+
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.mode == 'source' }}
112116
shell: bash
113117
run: |
114118
source ${{ inputs.oneapi }}/setvars.sh
@@ -117,11 +121,24 @@ runs:
117121
pip install -r requirements.txt
118122
python setup.py bdist_wheel
119123
120-
- name: Install PyTorch
124+
- name: Install PyTorch (built from source)
125+
if: ${{ inputs.mode == 'source' }}
121126
shell: bash
122127
run: |
123128
source ${{ inputs.oneapi }}/setvars.sh
124129
pip install pytorch/dist/*.whl
130+
131+
- name: Install PyTorch (from wheels)
132+
if: ${{ inputs.mode == 'wheels' }}
133+
shell: bash
134+
run: |
135+
source ${{ inputs.oneapi }}/setvars.sh
136+
pip install torch --index-url https://download.pytorch.org/whl/nightly/xpu
137+
138+
- name: Get PyTorch version
139+
shell: bash
140+
run: |
141+
source ${{ inputs.oneapi }}/setvars.sh
125142
PYTORCH_VERSION="$(python -c 'import torch;print(torch.__version__)')"
126143
echo "PYTORCH_VERSION=$PYTORCH_VERSION" | tee -a "$GITHUB_ENV"
127144

.github/workflows/build-test-gpu.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ on:
1212
description: PyTorch ref, keep empty for default
1313
type: string
1414
default: ""
15+
pytorch_mode:
16+
description: PyTorch mode, source or wheels
17+
type: choice
18+
options:
19+
- source
20+
- wheels
21+
default: source
1522
upload_test_reports:
1623
description: Upload test reports
1724
type: boolean
@@ -46,6 +53,7 @@ jobs:
4653
device: ${{ inputs.runner_label }}
4754
runner_label: ${{ inputs.runner_label }}
4855
pytorch_ref: ${{ inputs.pytorch_ref }}
56+
pytorch_mode: ${{ inputs.pytorch_mode || 'source' }}
4957
python_version: ${{ matrix.python }}
5058
upload_test_reports: ${{ inputs.upload_test_reports }}
5159
ignore_errors: ${{ inputs.ignore_errors }}

.github/workflows/build-test-reusable.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ on:
2020
description: PyTorch ref, keep empty for default
2121
type: string
2222
default: ""
23+
pytorch_mode:
24+
description: PyTorch mode, source or wheels
25+
type: string
26+
default: "source"
2327
python_version:
2428
description: Python version
2529
type: string
@@ -96,6 +100,7 @@ jobs:
96100
with:
97101
repository: pytorch/pytorch
98102
ref: ${{ inputs.pytorch_ref }}
103+
mode: ${{ inputs.pytorch_mode }}
99104

100105
- name: Install pass_rate dependencies
101106
run: |

.github/workflows/build-test.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ on:
1212
description: PyTorch ref, keep empty for default
1313
type: string
1414
default: ""
15+
pytorch_mode:
16+
description: PyTorch mode, source or wheels
17+
type: choice
18+
options:
19+
- source
20+
- wheels
21+
default: source
1522
upload_test_reports:
1623
description: Upload test reports
1724
type: boolean
@@ -120,6 +127,7 @@ jobs:
120127
driver_version: ${{ matrix.driver }}
121128
runner_label: ${{ inputs.runner_label }}
122129
pytorch_ref: ${{ inputs.pytorch_ref }}
130+
pytorch_mode: ${{ inputs.pytorch_mode || 'source' }}
123131
python_version: ${{ matrix.python }}
124132
upload_test_reports: ${{ inputs.upload_test_reports || false }}
125133
ignore_errors: ${{ inputs.ignore_errors || false }}

0 commit comments

Comments
 (0)