
Commit c74534d

Fix PyTorch inductor tests after #3112 (#3126)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 7a7ac12 commit c74534d

File tree

2 files changed: +258 additions, -41 deletions

scripts/patch-pytorch.sh

Lines changed: 1 addition & 1 deletion

@@ -17,5 +17,5 @@ echo "Applying PyTorch patches in $REPO_ROOT"
 cd "$REPO_ROOT"
 
 curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
-# build trigger
+# build trigger #3126
 git apply "${SCRIPT_DIR}/pytorch.patch"

scripts/pytorch.patch

Lines changed: 257 additions & 40 deletions
@@ -1,64 +1,165 @@
 diff --git a/test/inductor/test_codegen_triton.py b/test/inductor/test_codegen_triton.py
-index 84264bf1b01..aa9a624d5ac 100644
+index 84264bf1b0..54b6d028cf 100644
 --- a/test/inductor/test_codegen_triton.py
 +++ b/test/inductor/test_codegen_triton.py
-@@ -48,7 +48,7 @@ class TestCodegenTriton(InductorTestCase):
-             return config.divisible_by_16
+@@ -39,45 +39,31 @@ class TestCodegenTriton(InductorTestCase):
+         s0 = sympy.Symbol("s0", positive=True, integer=True)
+         s1 = sympy.Symbol("s1", positive=True, integer=True)
 
-         self.assertEqual(
+-        def _check_divisibility(config):
+-            try:
+-                from triton.backends.compiler import AttrsDescriptor  # noqa: F401
+-
+-                return config.divisibility_16
+-            except ImportError:
+-                return config.divisible_by_16
+-
+-        self.assertEqual(
 -            (2,),
-+            [(2,)],
-             _check_divisibility(
-                 triton_utils.config_of(
-                     [
-@@ -63,7 +63,7 @@ class TestCodegenTriton(InductorTestCase):
+-            _check_divisibility(
+-                triton_utils.config_of(
+-                    [
+-                        SizeArg("A", two),  # no
+-                        SizeArg("B", eight),  # no
+-                        SizeArg("C", sixteen),  # yes
+-                        SizeArg("D", s0),  # no
+-                        SizeArg("E", s1),  # no
+-                    ]
+-                )
+-            ),
++        config = triton_utils.config_of(
++            [
++                SizeArg("A", two),  # no
++                SizeArg("B", eight),  # no
++                SizeArg("C", sixteen),  # yes
++                SizeArg("D", s0),  # no
++                SizeArg("E", s1),  # no
++            ]
          )
-
-         self.assertEqual(
+-
+-        self.assertEqual(
 -            (0, 2, 4, 5, 6),
-+            [(0,), (2,), (4,), (5,), (6,)],
-             _check_divisibility(
-                 triton_utils.config_of(
-                     [
+-            _check_divisibility(
+-                triton_utils.config_of(
+-                    [
+-                        SizeArg("A", two * eight),  # 0: yes
+-                        SizeArg("B", eight * s0),  # 1: no
+-                        SizeArg("C", two * eight * s0),  # 2: yes
+-                        SizeArg("D", s0 * s1),  # 3: no
+-                        SizeArg("E", sixteen * s0),  # 4: yes
+-                        SizeArg("F", sixteen * eight * s0 * s1),  # 5: yes
+-                        SizeArg("G", two * eight * s0 * s1),  # 6: yes
+-                    ]
+-                )
+-            ),
++        # check for key
++        config[((2,),)]
++
++        config = triton_utils.config_of(
++            [
++                SizeArg("A", two * eight),  # 0: yes
++                SizeArg("B", eight * s0),  # 1: no
++                SizeArg("C", two * eight * s0),  # 2: yes
++                SizeArg("D", s0 * s1),  # 3: no
++                SizeArg("E", sixteen * s0),  # 4: yes
++                SizeArg("F", sixteen * eight * s0 * s1),  # 5: yes
++                SizeArg("G", two * eight * s0 * s1),  # 6: yes
++            ]
+         )
++        # check for key
++        config[((0,), (2,), (4,), (5,), (6,))]
+
+
+ if __name__ == "__main__":
 diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py
-index 4d7a85029e3..f3d45ea5520 100644
+index f674262ab6..fc2b8a6c27 100644
 --- a/test/inductor/test_triton_kernels.py
 +++ b/test/inductor/test_triton_kernels.py
-@@ -1268,9 +1268,9 @@ def forward(self, x_1, output_1):
-         if dynamic:
-             # when half_n_elements passed to the Triton kernel is
-             # dynamic, equal_to_1 specializaiton can't be enforced
+@@ -55,14 +55,6 @@ if HAS_GPU:
+         fast_dividef as my_fast_dividef,
+     )
+
+-    def _triton_get_ast_equal_to_str(params):
+-        try:
+-            from triton.backends.compiler import AttrsDescriptor  # noqa: F401
+-
+-            return f"'tt.equal_to': {params}"
+-        except ImportError:
+-            return f"equal_to_1={params}"
+-
+     # Define shared triton constants here.
+     CONSTANT_C: tl.constexpr = 4
+     STRING_CONSTANT_C: tl.constexpr = "CONSTANT_C"
+@@ -1266,12 +1258,6 @@ def forward(self, x_1, output_1):
+             torch.compile(f, dynamic=dynamic), x, y
+         )
+
+-        if dynamic:
+-            # when half_n_elements passed to the Triton kernel is
+-            # dynamic, equal_to_1 specializaiton can't be enforced
 -        self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
-+        self.assertTrue(_triton_get_ast_equal_to_str([]) in sources[0])
-         else:
+-        else:
 -            self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0])
-+            self.assertTrue(_triton_get_ast_equal_to_str([(3,)]) in sources[0])
          self.assertEqual(compiled_out, eager_out)
 
      @requires_gpu
-@@ -1299,7 +1299,7 @@ def forward(self, x_1, output_1):
+@@ -1298,9 +1284,6 @@ def forward(self, x_1, output_1):
+             torch.compile(f, dynamic=dynamic), x, y
+         )
 
-         # float 1.0 (both literal or symbolic)
-         # should not be added to equal_to_1
+-        # float 1.0 (both literal or symbolic)
+-        # should not be added to equal_to_1
 -        self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
-+        self.assertTrue(_triton_get_ast_equal_to_str([]) in sources[0])
          self.assertEqual(compiled_out, eager_out)
 
      @requires_gpu
 diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
-index c3f72bc5215..03aab72dca9 100644
+index c3f72bc521..016e30790a 100644
 --- a/torch/_higher_order_ops/triton_kernel_wrap.py
 +++ b/torch/_higher_order_ops/triton_kernel_wrap.py
-@@ -239,7 +239,7 @@ def generate_ttir(
+@@ -184,6 +184,7 @@ def generate_ttir(
+     from triton.compiler.compiler import ASTSource
+     from triton.runtime.autotuner import Autotuner
+     from triton.runtime.jit import JITFunction
++    from triton._utils import find_paths_if, get_iterable_path
+
+     import torch._inductor.ir
+     from torch._subclasses.fake_tensor import FakeTensor
+@@ -233,24 +234,40 @@ def generate_ttir(
+         name for name, arg in ordered_args.items() if isinstance(arg, Tensor)
+     ]
+
+-    def _get_specialization(args):  # type: ignore[no-untyped-def]
++    def _get_specialization(kernel, *args):  # type: ignore[no-untyped-def]
+         try:
+-            from triton.backends.compiler import AttrsDescriptor  # noqa: F401
++            from triton.runtime.jit import create_function_from_signature
 
              target = triton.runtime.driver.active.get_current_target()
              backend = triton.compiler.compiler.make_backend(target)
 -            return backend.get_attrs_descriptor(args, kernel.params)
-+            return backend.get_attrs_descriptor(kernel.params, args)
++            # from: binder = create_function_from_signature(self.signature, self.params, backend)
++            specialization = []
++            # signature
++            for name, kp in zip(kernel.signature.parameters.keys(), kernel.params):
++                if kp.is_constexpr:
++                    specialization.append(f'("constexpr", {name})')
++                else:
++                    is_const = 'True' if kp.is_const else 'False'
++                    specialize = 'False' if kp.do_not_specialize else 'True'
++                    align = 'False' if kp.do_not_specialize_on_alignment else 'True'
++                    ret = f"specialize_impl({name}, specialize_extra, {is_const}, {specialize}, {align})"
++                    if kp.annotation_type:
++                        specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
++                    else:
++                        specialization.append(f"{ret}")
++            return specialization
         except ImportError:
             return kernel._get_config(*args)
 
-@@ -248,9 +248,10 @@ def generate_ttir(
+-    specialization = _get_specialization(ordered_args.values())
++    specialization = _get_specialization(kernel, ordered_args.values())
+     constants = {
         name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor)
     }
 
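The heart of this hunk is the rewritten _get_specialization in triton_kernel_wrap.py: on Triton builds where triton.backends.compiler.AttrsDescriptor no longer exists, the patch rebuilds the per-parameter specialization strings that Triton's own binder (create_function_from_signature) would emit, rather than asking the backend for a descriptor object. A minimal sketch of that string-building step — illustrative only, using simplified (name, is_constexpr, annotation_type) records in place of Triton's real kernel.params objects:

# Sketch: one specialization string per kernel parameter, mirroring the
# patched _get_specialization. The real code also threads the is_const,
# do_not_specialize, and alignment flags read off each KernelParam.
def specialization_strings(params):
    out = []
    for name, is_constexpr, annotation_type in params:
        if is_constexpr:
            out.append(f'("constexpr", {name})')
        else:
            ret = f"specialize_impl({name}, specialize_extra, False, True, True)"
            if annotation_type:
                out.append(f'("{annotation_type}",) + {ret}[1:]')
            else:
                out.append(ret)
    return out

print(specialization_strings([("n_elements", False, "i32"), ("BLOCK", True, None)]))
# ['("i32",) + specialize_impl(n_elements, specialize_extra, False, True, True)[1:]',
#  '("constexpr", BLOCK)']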

@@ -71,7 +172,7 @@ index c3f72bc5215..03aab72dca9 100644
          for i, (name, arg) in enumerate(ordered_args.items())
          if i not in kernel.constexprs
      }
-@@ -258,7 +259,18 @@ def generate_ttir(
+@@ -258,7 +275,22 @@ def generate_ttir(
      triton._C.libtriton.ir.load_dialects(context)
      backend.load_dialects(context)
 
@@ -87,10 +188,105 @@ index c3f72bc5215..03aab72dca9 100644
 +        new_signature[arg_name] = signature[arg_name]
 +    return new_signature
 +
-+    src = ASTSource(kernel, fixup_signature(kernel.arg_names, signature, constants), constants, specialization)
++    attrvals = [x[1] for x in specialization]
++    from triton._utils import find_paths_if, get_iterable_path
++    attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
++    attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
++    src = ASTSource(kernel, fixup_signature(kernel.arg_names, signature, constants), constants, attrs)
 
      # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle
      # backward compatibility here.
+diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
+index a5fe2d1119..5ea0018957 100644
+--- a/torch/_inductor/ir.py
++++ b/torch/_inductor/ir.py
+@@ -5743,52 +5743,6 @@ class UserDefinedTritonKernel(ExternKernel):
+         for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
+             if kernel.arg_names.index(kwarg) in kernel.constexprs:
+                 constexpr_indices.append(idx)
+-        """
+-        Filter out None args.
+-
+-        see https://github.com/pytorch/pytorch/issues/115344
+-
+-        Two cases for a None arg:
+-        1. The arg is already tl.constexpr, so leave it in
+-        2. The arg is not tl.constexpr so we have to remove it
+-        """
+-        constexpr_indices_set = OrderedSet(constexpr_indices)
+-        REMOVED = object()
+-        raw_args = [
+-            (
+-                (idx, arg)
+-                if (arg is not None) or (arg is None and idx in constexpr_indices_set)
+-                else (idx, REMOVED)
+-            )
+-            for idx, arg in enumerate(raw_args)
+-        ]
+-        removed_none_args = [idx for idx, val in raw_args if val == REMOVED]
+-        raw_args = [val for idx, val in raw_args if val != REMOVED]
+-
+-        # We have to compute the constexpr indices for the new, filtered raw_args
+-        # We also have to adjust equal_to_1.
+-        if removed_none_args:
+-            eq1_indices_set = OrderedSet[int](triton_meta["configs"][0].equal_to_1)
+-            constexpr_indices = []
+-            equal_to_1 = []
+-            index_shift = 0
+-            for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
+-                # every time we encounter an idx we removed, adjust by one to account for it
+-                # So for example if we had [None, const X]
+-                # iter 1:
+-                #   None was removed, adjust=1
+-                # iter 2:
+-                #   X is const at idx=1, but the adjusted idx is 0 now, because None was removed
+-                if idx in removed_none_args:
+-                    index_shift += 1
+-                    continue
+-                arg_index = kernel.arg_names.index(kwarg)
+-                if arg_index in kernel.constexprs:
+-                    constexpr_indices.append(idx - index_shift)
+-                if arg_index in eq1_indices_set:
+-                    equal_to_1.append(idx - index_shift)
+-
+-            triton_meta["configs"][0].equal_to_1 = equal_to_1
+ 
+         # Call to kernel
+         self.codegen_comment(wrapper)
+diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py
+index ccd69bf828..2c89659132 100644
+--- a/torch/_inductor/codegen/cpp_wrapper_gpu.py
++++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py
+@@ -551,28 +551,8 @@ class CppWrapperGpu(CppWrapperCpu):
+         )
+         kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph)
+ 
+-        # args with value 1 are added into equal_to_1 and constants
+-        # in triton_meta (in the Python codegen) which makes them
+-        # inlined in the PTX and compiled CUBIN
+-        arg_signatures = []
+-        if (
+-            triton_meta is not None
+-            and triton_meta.get("configs")
+-            and triton_meta.get("signature")
+-        ):
+-            equal_to_1 = triton_meta["configs"][0].equal_to_1
+-            call_args = [
+-                arg for i, arg in enumerate(call_args) if i not in equal_to_1
+-            ]
+-            arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1]
+-            # extract the arg signatures from triton_meta
+-            arg_signatures = triton_meta["signature"].values()
+-            arg_signatures = [
+-                v for i, v in enumerate(arg_signatures) if i not in equal_to_1
+-            ]
+-
+         call_args_str = self.generate_args_decl(
+-            call_args, arg_types, arg_signatures
++            call_args, arg_types, list(triton_meta["signature"].values())
+         )
+         kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
+         self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};")
 diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
 index 00031a56b8d..59086d41b40 100644
 --- a/torch/_inductor/codegen/triton.py
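
The attrs construction in the hunk above leans on two helpers from triton._utils, find_paths_if and get_iterable_path. Paraphrasing their assumed behavior (re-implemented here for illustration, not copied from Triton): the first returns the index paths of all leaves in a nested list/tuple that satisfy a predicate, and the second follows such a path back to the leaf:

def find_paths_if(obj, pred, path=()):
    # Collect index paths of leaves for which pred(path, leaf) is true.
    if isinstance(obj, (list, tuple)):
        found = []
        for i, item in enumerate(obj):
            found += find_paths_if(item, pred, path + (i,))
        return found
    return [path] if pred(path, obj) else []

def get_iterable_path(obj, path):
    # Follow a path of indices into a nested iterable.
    for i in path:
        obj = obj[i]
    return obj

attrvals = [["i32", "D"], ["constexpr"], [1]]
paths = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
assert paths == [(0, 0), (0, 1), (1, 0)]
assert get_iterable_path(attrvals, (0, 1)) == "D"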
@@ -173,7 +369,7 @@ index fa2a1334380..4d730fd45de 100644
      except ImportError:
          from triton.compiler.compiler import AttrsDescriptor
  diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
-index 281d0e78ba4..3b059a365c9 100644
+index 281d0e78ba..8b3857bf2e 100644
 --- a/torch/_inductor/runtime/triton_heuristics.py
 +++ b/torch/_inductor/runtime/triton_heuristics.py
 @@ -414,10 +414,21 @@ class CachingAutotuner(KernelInterface):
@@ -213,19 +409,40 @@ index 281d0e78ba4..3b059a365c9 100644
          ]
          binary_shared = (
              binary.shared if hasattr(binary, "shared") else binary.metadata.shared
-@@ -952,6 +961,7 @@ class CachingAutotuner(KernelInterface):
+@@ -646,8 +655,11 @@ class CachingAutotuner(KernelInterface):
+         )
+         # reset to zero before evaluating any config
+         self.reset_to_zero_args(*args, **kwargs)
++        new_cloned_args = [*cloned_args]
++        for arg_name, arg_value in launcher.config.kwargs.items():
++            new_cloned_args.insert(self.fn.arg_names.index(arg_name), arg_value)
+         launcher(
+-            *cloned_args,
++            *new_cloned_args,
+             **cloned_kwargs,
+             grid=grid,
+             stream=stream,
+@@ -950,15 +962,21 @@ class CachingAutotuner(KernelInterface):
+                 "stream": stream,
+             },
          ):
++            new_args = [*args]
++            for arg_name, arg_value in launcher.config.kwargs.items():
++                new_args.insert(self.fn.arg_names.index(arg_name), arg_value)
              return launcher(
-                 *args,
-+                **launcher.config.kwargs,
+-                *args,
++                *new_args,
                  **kwargs,
                  grid=grid,
                  stream=stream,
-@@ -959,6 +969,7 @@ class CachingAutotuner(KernelInterface):
+             )
          else:
++            new_args = [*args]
++            for arg_name, arg_value in launcher.config.kwargs.items():
++                new_args.insert(self.fn.arg_names.index(arg_name), arg_value)
              return launcher(
-                 *args,
-+                **launcher.config.kwargs,
+-                *args,
++                *new_args,
                  **kwargs,
                  grid=grid,
                  stream=stream,
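
The triton_heuristics.py change replaces keyword-splatting of the autotuned config (**launcher.config.kwargs) with positional insertion, since newer Triton launchers expect every kernel argument positionally. Each config value is spliced into the argument list at the index its parameter occupies in the kernel signature. A self-contained sketch of that splice, with hypothetical argument names:

def splice_config_kwargs(arg_names, args, config_kwargs):
    # Insert each autotuned config value (e.g. BLOCK_SIZE) at the position its
    # parameter holds in the kernel signature, as the patch does via
    # self.fn.arg_names.index(arg_name).
    new_args = [*args]
    for name, value in config_kwargs.items():
        new_args.insert(arg_names.index(name), value)
    return new_args

# Kernel signature (x_ptr, n_elements, BLOCK_SIZE); the config supplies BLOCK_SIZE.
assert splice_config_kwargs(
    ["x_ptr", "n_elements", "BLOCK_SIZE"], ["ptr0", 1024], {"BLOCK_SIZE": 128}
) == ["ptr0", 1024, 128]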
