Skip to content

Commit 496c07e

Browse files
[PATCH] Update flex_decoding_tensor_desc.patch (#4989)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 16c5e1f commit 496c07e

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

scripts/patch/flex_decoding_tensor_desc.patch

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
2-
index 679caa9f09..08fb95b03b 100644
2+
index 679caa9f09..4c89d1a669 100644
33
--- a/torch/_inductor/kernel/flex/flex_decoding.py
44
+++ b/torch/_inductor/kernel/flex/flex_decoding.py
5-
@@ -326,6 +326,10 @@ def create_flex_decoding_kernel(*args, **kwargs):
5+
@@ -6,6 +6,7 @@ from typing import Any
6+
import sympy
7+
8+
import torch
9+
+from torch._inductor.utils import can_use_tma
10+
from torch._inductor.virtualized import V
11+
12+
from ... import ir
13+
@@ -326,6 +327,10 @@ def create_flex_decoding_kernel(*args, **kwargs):
614
# Set default to False
715
cur_kernel_options.setdefault("USE_TMA", False)
816

0 commit comments

Comments
 (0)