Skip to content

Commit 26351e5

Browse files
committed
Add pytorch_mode: source ot wheels
1 parent e4416dd commit 26351e5

File tree

4 files changed

+28
-3
lines changed

4 files changed

+28
-3
lines changed

.github/actions/setup-pytorch/action.yml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ inputs:
1616
ref:
1717
description: Branch, tag, commit id
1818
default: ""
19+
mode:
20+
description: Source or wheels
21+
default: source
1922
runs:
2023
using: "composite"
2124
steps:
@@ -71,7 +74,7 @@ runs:
7174
- name: Generate PyTorch cache key
7275
shell: bash
7376
run: |
74-
PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} | sha256sum - | cut -d\ -f1)
77+
PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} ${{ inputs.mode }} | sha256sum - | cut -d\ -f1)
7578
echo "PYTORCH_CACHE_KEY=$PYTORCH_CACHE_KEY" | tee -a "$GITHUB_ENV"
7679
7780
- name: Load PyTorch from a cache
@@ -94,7 +97,7 @@ runs:
9497
path: pytorch
9598

9699
- name: Apply additional PR patches
97-
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' }}
100+
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' && inputs.mode == 'source' }}
98101
shell: bash
99102
run: |
100103
cd pytorch
@@ -108,7 +111,7 @@ runs:
108111
pip install 'numpy<2.0.0'
109112
110113
- name: Build PyTorch
111-
if: ${{ steps.pytorch-cache.outputs.status == 'miss' }}
114+
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.mode == 'source' }}
112115
shell: bash
113116
run: |
114117
source ${{ inputs.oneapi }}/setvars.sh
@@ -117,6 +120,13 @@ runs:
117120
pip install -r requirements.txt
118121
python setup.py bdist_wheel
119122
123+
- name: Download PyTorch wheels
124+
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.mode == 'wheels' }}
125+
shell: bash
126+
run: |
127+
mkdir -p pytorch/dist
128+
pip download torch --index-url https://download.pytorch.org/whl/nightly/xpu --dest pytorch/dist/
129+
120130
- name: Install PyTorch
121131
shell: bash
122132
run: |

.github/workflows/build-test-gpu.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ on:
1212
description: PyTorch ref, keep empty for default
1313
type: string
1414
default: ""
15+
pytorch_mode:
16+
description: PyTorch mode, source or wheels
17+
type: string
18+
default: "source"
1519
upload_test_reports:
1620
description: Upload test reports
1721
type: boolean
@@ -46,6 +50,7 @@ jobs:
4650
device: ${{ inputs.runner_label }}
4751
runner_label: ${{ inputs.runner_label }}
4852
pytorch_ref: ${{ inputs.pytorch_ref }}
53+
pytorch_mode: ${{ inputs.pytorch_mode }}
4954
python_version: ${{ matrix.python }}
5055
upload_test_reports: ${{ inputs.upload_test_reports }}
5156
ignore_errors: ${{ inputs.ignore_errors }}

.github/workflows/build-test-reusable.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ on:
2020
description: PyTorch ref, keep empty for default
2121
type: string
2222
default: ""
23+
pytorch_mode:
24+
description: PyTorch mode, source or wheels
25+
type: string
26+
default: "source"
2327
python_version:
2428
description: Python version
2529
type: string
@@ -96,6 +100,7 @@ jobs:
96100
with:
97101
repository: pytorch/pytorch
98102
ref: ${{ inputs.pytorch_ref }}
103+
mode: ${{ inputs.pytorch_mode }}
99104

100105
- name: Install pass_rate dependencies
101106
run: |

.github/workflows/build-test.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ on:
1212
description: PyTorch ref, keep empty for default
1313
type: string
1414
default: ""
15+
pytorch_mode:
16+
description: PyTorch mode, source or wheels
17+
type: string
18+
default: "source"
1519
upload_test_reports:
1620
description: Upload test reports
1721
type: boolean
@@ -120,6 +124,7 @@ jobs:
120124
driver_version: ${{ matrix.driver }}
121125
runner_label: ${{ inputs.runner_label }}
122126
pytorch_ref: ${{ inputs.pytorch_ref }}
127+
pytorch_mode: ${{ inputs.pytorch_mode }}
123128
python_version: ${{ matrix.python }}
124129
upload_test_reports: ${{ inputs.upload_test_reports || false }}
125130
ignore_errors: ${{ inputs.ignore_errors || false }}

0 commit comments

Comments
 (0)