Skip to content

Commit 386e419

Browse files
[PATCH] Update use_tma.patch (#5053)
The patch can be reduced due to pytorch/pytorch#162138. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 659d012 commit 386e419

File tree

2 files changed

+1
-14
lines changed

2 files changed

+1
-14
lines changed

.github/actions/setup-pytorch/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ runs:
8282
uses: ./.github/actions/load
8383
env:
8484
# Increase this value to reset cache
85-
CACHE_NUMBER: 16
85+
CACHE_NUMBER: 17
8686
with:
8787
path: pytorch
8888
key: pytorch-$PYTORCH_CACHE_KEY-$CACHE_NUMBER

scripts/patch/use_tma.patch

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,3 @@ index 91ba941da0..a6b87212ad 100644
3434
# Add ROCm-specific parameters if they exist in the config
3535
for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
3636
if hasattr(conf, attrib):
37-
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
38-
index 0876f99307..4fa1c87560 100644
39-
--- a/torch/_inductor/utils.py
40-
+++ b/torch/_inductor/utils.py
41-
@@ -1696,7 +1696,7 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
42-
strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
43-
44-
# Every logical size ≥ 2
45-
- if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
46-
+ if not torch.xpu.is_available() and any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
47-
return False
48-
49-
# Find the single contiguous (“inner”) dim

0 commit comments

Comments (0)