@@ -110,6 +110,24 @@ index 00031a56b8d..59086d41b40 100644
110110 self.triton_meta = triton_meta
111111
112112 for tree in self.range_trees:
113+ diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
114+ index 8b8c29bbb15..c89a76e9868 100644
115+ --- a/torch/_inductor/codegen/triton_utils.py
116+ +++ b/torch/_inductor/codegen/triton_utils.py
117+ @@ -165,12 +165,4 @@ def config_of(
118+ else:
119+ divisible_by_16 = ()
120+
121+ - equal_to_1 = tuple(
122+ - i
123+ - for i, arg in zip(indices, args)
124+ - if isinstance(arg, SizeArg)
125+ - and isinstance(arg.expr, (int, sympy.Integer))
126+ - and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type]
127+ - )
128+ -
129+ - return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
130+ + return AttrsDescriptorWrapper(divisible_by_16)
113131diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
114132index 2ab2b326354..42d76b8bf94 100644
115133--- a/torch/_inductor/codegen/wrapper.py
@@ -123,25 +141,24 @@ index 2ab2b326354..42d76b8bf94 100644
123141 "configs": [
124142 config_of(
125143diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
126- index fa2a1334380..11de4dc6e5f 100644
144+ index fa2a1334380..4d730fd45de 100644
127145--- a/torch/_inductor/runtime/hints.py
128146+++ b/torch/_inductor/runtime/hints.py
129- @@ -44,25 +44,16 @@ def _is_triton_available() -> bool:
147+ @@ -44,25 +44,14 @@ def _is_triton_available() -> bool:
130148 # Define `AttrsDescriptorWrapper` function with clear conditional handling
131149 if _is_triton_available():
132150 try:
133151- from triton.backends.compiler import AttrsDescriptor
134152
135153 def AttrsDescriptorWrapper(
136154 divisible_by_16=None,
137- equal_to_1=None,
155+ - equal_to_1=None,
138156 ):
139157- # Prepare the arguments for AttrsDescriptor
140158 kwargs = {
141159- "tt.divisibility": divisible_by_16,
142160- "tt.equal_to": equal_to_1,
143161+ tuple([(i,) for i in divisible_by_16]): [["tt.divisibility", 16]],
144- + tuple([(i,) for i in equal_to_1]): [["tt.equal_to", 1]],
145162 }
146163-
147164- # Instantiate AttrsDescriptor with the prepared arguments
@@ -156,7 +173,7 @@ index fa2a1334380..11de4dc6e5f 100644
156173 except ImportError:
157174 from triton.compiler.compiler import AttrsDescriptor
158175diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
159- index 281d0e78ba4..8e7ca6b5822 100644
176+ index 281d0e78ba4..3b059a365c9 100644
160177--- a/torch/_inductor/runtime/triton_heuristics.py
161178+++ b/torch/_inductor/runtime/triton_heuristics.py
162179@@ -414,10 +414,21 @@ class CachingAutotuner(KernelInterface):
@@ -196,3 +213,19 @@ index 281d0e78ba4..8e7ca6b5822 100644
196213 ]
197214 binary_shared = (
198215 binary.shared if hasattr(binary, "shared") else binary.metadata.shared
216+ @@ -952,6 +961,7 @@ class CachingAutotuner(KernelInterface):
217+ ):
218+ return launcher(
219+ *args,
220+ + **launcher.config.kwargs,
221+ **kwargs,
222+ grid=grid,
223+ stream=stream,
224+ @@ -959,6 +969,7 @@ class CachingAutotuner(KernelInterface):
225+ else:
226+ return launcher(
227+ *args,
228+ + **launcher.config.kwargs,
229+ **kwargs,
230+ grid=grid,
231+ stream=stream,
0 commit comments