Skip to content

Commit 1c7df8a

Browse files
committed
push_tests_mps
1 parent edb8c1b commit 1c7df8a

File tree

1 file changed

+94
-12
lines changed

1 file changed

+94
-12
lines changed

.github/workflows/push_tests_mps.yml

Lines changed: 94 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,55 @@ env:
1616
HF_HUB_ENABLE_HF_TRANSFER: 1
1717
PYTEST_TIMEOUT: 600
1818
RUN_SLOW: no
19+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
20+
PYTORCH_MPS_HIGH_WATERMARK_RATIO: 0.0
1921

2022
concurrency:
2123
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
2224
cancel-in-progress: true
2325

2426
jobs:
2527
run_fast_tests_apple_m1:
26-
name: Fast PyTorch MPS tests on MacOS
27-
runs-on: macos-13-xlarge
28+
strategy:
29+
fail-fast: false
30+
matrix:
31+
config:
32+
- name: Fast Pipelines MPS tests
33+
framework: pytorch_pipelines
34+
runner: macos-13-xlarge
35+
report: torch_mps_pipelines
36+
- name: Fast Models MPS tests
37+
framework: pytorch_models
38+
runner: macos-13-xlarge
39+
report: torch_mps_models
40+
- name: Fast Schedulers MPS tests
41+
framework: pytorch_schedulers
42+
runner: macos-13-xlarge
43+
report: torch_mps_schedulers
44+
- name: Fast Others MPS tests
45+
framework: pytorch_others
46+
runner: macos-13-xlarge
47+
report: torch_mps_others
48+
- name: Fast Single File MPS tests
49+
framework: pytorch_single_file
50+
runner: macos-13-xlarge
51+
report: torch_mps_single_file
52+
- name: Fast Lora MPS tests
53+
framework: pytorch_lora
54+
runner: macos-13-xlarge
55+
report: torch_mps_lora
56+
- name: Fast Quantization MPS tests
57+
framework: pytorch_quantization
58+
runner: macos-13-xlarge
59+
report: torch_mps_quantization
60+
61+
name: ${{ matrix.config.name }}
62+
63+
runs-on: ${{ matrix.config.runner }}
64+
65+
defaults:
66+
run:
67+
shell: arch -arch arm64 bash {0}
2868

2969
steps:
3070
- name: Checkout diffusers
@@ -33,7 +73,6 @@ jobs:
3373
fetch-depth: 2
3474

3575
- name: Clean checkout
36-
shell: arch -arch arm64 bash {0}
3776
run: |
3877
git clean -fxd
3978
@@ -43,10 +82,12 @@ jobs:
4382
python-version: 3.9
4483

4584
- name: Install dependencies
46-
shell: arch -arch arm64 bash {0}
4785
run: |
4886
${CONDA_RUN} python -m pip install --upgrade pip uv
4987
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
88+
${CONDA_RUN} python -m uv pip install hf_transfer
89+
${CONDA_RUN} python -m uv pip install peft
90+
${CONDA_RUN} python -m uv pip install bitsandbytes
5091
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
5192
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
5293
${CONDA_RUN} python -m uv pip install transformers --upgrade
@@ -56,21 +97,62 @@ jobs:
5697
run: |
5798
${CONDA_RUN} python utils/print_env.py
5899
59-
- name: Run fast PyTorch tests on M1 (MPS)
60-
shell: arch -arch arm64 bash {0}
61-
env:
62-
HF_HOME: /System/Volumes/Data/mnt/cache
63-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
100+
- name: Run fast PyTorch Pipeline MPS tests
101+
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
102+
run: |
103+
${CONDA_RUN} python -m pytest -n 0 -s -v -k "not Flax and not Onnx" \
104+
--make-reports=tests_${{ matrix.config.report }} \
105+
tests/pipelines/
106+
107+
- name: Run fast PyTorch Models MPS tests
108+
if: ${{ matrix.config.framework == 'pytorch_models' }}
109+
run: |
110+
${CONDA_RUN} python -m pytest -n 0 -s -v -k "not Flax and not Onnx and not Dependency" \
111+
--make-reports=tests_${{ matrix.config.report }} \
112+
tests/models/
113+
114+
- name: Run fast PyTorch Schedulers MPS tests
115+
if: ${{ matrix.config.framework == 'pytorch_schedulers' }}
116+
run: |
117+
${CONDA_RUN} python -m pytest -n 0 -s -v -k "not Flax and not Onnx and not Dependency" \
118+
--make-reports=tests_${{ matrix.config.report }} \
119+
tests/schedulers/
120+
121+
- name: Run fast PyTorch Others MPS tests
122+
if: ${{ matrix.config.framework == 'pytorch_others' }}
123+
run: |
124+
${CONDA_RUN} python -m pytest -n 0 -s -v \
125+
--make-reports=tests_${{ matrix.config.report }} \
126+
tests/others/
127+
128+
- name: Run fast PyTorch Single File MPS tests
129+
if: ${{ matrix.config.framework == 'pytorch_single_file' }}
130+
run: |
131+
${CONDA_RUN} python -m pytest -n 0 -s -v \
132+
--make-reports=tests_${{ matrix.config.report }} \
133+
tests/single_file/
134+
135+
- name: Run fast PyTorch Lora MPS tests
136+
if: ${{ matrix.config.framework == 'pytorch_lora' }}
137+
run: |
138+
${CONDA_RUN} python -m pytest -n 0 -s -v \
139+
--make-reports=tests_${{ matrix.config.report }} \
140+
tests/lora/
141+
142+
- name: Run fast PyTorch Quantization MPS tests
143+
if: ${{ matrix.config.framework == 'pytorch_quantization' }}
64144
run: |
65-
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
145+
${CONDA_RUN} python -m pytest -n 0 -s -v \
146+
--make-reports=tests_${{ matrix.config.report }} \
147+
tests/quantization/
66148
67149
- name: Failure short reports
68150
if: ${{ failure() }}
69-
run: cat reports/tests_torch_mps_failures_short.txt
151+
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
70152

71153
- name: Test suite reports artifacts
72154
if: ${{ always() }}
73155
uses: actions/upload-artifact@v4
74156
with:
75-
name: pr_torch_mps_test_reports
157+
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
76158
path: reports

0 commit comments

Comments
 (0)