Skip to content

Commit 4a53fb0

Browse files
authored
[Flex Attention] Apply patch from pytorch#143553 instead of using fork (#3945)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 09e0a29 commit 4a53fb0

File tree

4 files changed: +4 additions, -18 deletions

.github/actions/setup-pytorch/action.yml

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -45,14 +45,8 @@ runs:
4545
if: inputs.ref != ''
4646
shell: bash
4747
run: |
48-
if [[ "${{ inputs.repository }}" = "liangan1/pytorch" ]]; then
49-
PYTORCH_COMMIT_ID="$(<.github/pins/pytorchFlexAttention.txt)"
50-
echo "PYTORCH_REPO=${{ inputs.repository }}" | tee -a "$GITHUB_ENV"
51-
echo "PYTORCH_COMMIT_ID=$PYTORCH_COMMIT_ID" | tee -a "$GITHUB_ENV"
52-
else
53-
echo "PYTORCH_REPO=${{ inputs.repository }}" | tee -a "$GITHUB_ENV"
54-
echo "PYTORCH_COMMIT_ID=${{ steps.commit-id.outputs.commit_id }}" | tee -a "$GITHUB_ENV"
55-
fi
48+
echo "PYTORCH_REPO=${{ inputs.repository }}" | tee -a "$GITHUB_ENV"
49+
echo "PYTORCH_COMMIT_ID=${{ steps.commit-id.outputs.commit_id }}" | tee -a "$GITHUB_ENV"
5650
5751
- name: Identify Python version
5852
shell: bash
@@ -105,7 +99,7 @@ runs:
10599
path: pytorch
106100

107101
- name: Apply additional PR patches
108-
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.mode == 'source' && (inputs.repository == 'pytorch/pytorch' || inputs.repository == 'liangan1/pytorch') }}
102+
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' && inputs.mode == 'source' }}
109103
shell: bash
110104
run: |
111105
cd pytorch

.github/pins/pytorchFlexAttention.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

.github/workflows/triton-benchmarks.yml

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -274,14 +274,6 @@ jobs:
274274
cd benchmarks/micro_benchmarks
275275
python run_benchmarks.py --reports $REPORTS
276276
277-
# Install Pytorch with FlexAttention XPU support enabled
278-
- name: Setup PyTorch
279-
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
280-
uses: ./.github/actions/setup-pytorch
281-
with:
282-
repository: liangan1/pytorch
283-
ref: liangan1/flex_attention
284-
285277
- name: Run Triton FlexAttention Causal Mask fwd kernel benchmark
286278
if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_causal_mask.py') }}
287279
run: |

scripts/patch-pytorch.sh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,3 +18,4 @@ cd "$REPO_ROOT"
1818

1919
# curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
2020
git apply "${SCRIPT_DIR}/pytorch_fp64.patch"
21+
curl -sSL https://github.com/pytorch/pytorch/pull/143553.diff | git apply --exclude=test/inductor/test_flex_attention.py --exclude=test/inductor/test_flex_decoding.py -

0 commit comments

Comments (0)