
Commit 8f1f00e

Fix PyTorch inductor tests workflow, which doesn't work due to the Triton interface change in #3043 (#3080)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 82af4c7 commit 8f1f00e
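For context, the core of this patch is adapting Inductor's Triton integration to the newer Triton compile interface, where the signature handed to ASTSource is expected to cover every kernel parameter, with constexpr parameters mapped to the string "constexpr". The sketch below mirrors (in slightly simplified form) the fixup_signature helper this patch adds in triton_kernel_wrap.py and triton_heuristics.py; the example argument names and types are hypothetical and only illustrate the transformation.

# Minimal sketch of the signature fixup introduced by this patch.
# The example kernel arguments below are made up for illustration.

def fixup_signature(arg_names, signature, constants):
    """Return a signature covering every kernel argument.

    Arguments that appear only in `constants` are constexpr parameters;
    the newer Triton interface expects them in the signature with the
    type string "constexpr".
    """
    new_signature = {}
    for arg_name in arg_names:
        if arg_name in constants and arg_name not in signature:
            # Not in the signature already, so it is a constexpr argument.
            new_signature[arg_name] = "constexpr"
        else:
            new_signature[arg_name] = signature[arg_name]
    return new_signature


if __name__ == "__main__":
    # Hypothetical kernel: two pointers, a runtime size, and one constexpr.
    arg_names = ["in_ptr", "out_ptr", "n_elements", "BLOCK_SIZE"]
    signature = {"in_ptr": "*fp32", "out_ptr": "*fp32", "n_elements": "i32"}
    constants = {"BLOCK_SIZE": 1024}
    print(fixup_signature(arg_names, signature, constants))
    # {'in_ptr': '*fp32', 'out_ptr': '*fp32', 'n_elements': 'i32', 'BLOCK_SIZE': 'constexpr'}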

File tree

2 files changed

+84
-179
lines changed


scripts/patch-pytorch.sh

Lines changed: 0 additions & 1 deletion
@@ -17,5 +17,4 @@ echo "Applying PyTorch patches in $REPO_ROOT"
 cd "$REPO_ROOT"
 
 curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
-# REVERT ME: it's just a trigger for pytorch rebuild
 git apply "${SCRIPT_DIR}/pytorch.patch"

scripts/pytorch.patch

Lines changed: 84 additions & 178 deletions
@@ -1,5 +1,5 @@
 diff --git a/test/inductor/test_codegen_triton.py b/test/inductor/test_codegen_triton.py
-index 84264bf1b0..eb3fac0f39 100644
+index 84264bf1b01..aa9a624d5ac 100644
 --- a/test/inductor/test_codegen_triton.py
 +++ b/test/inductor/test_codegen_triton.py
 @@ -48,7 +48,7 @@ class TestCodegenTriton(InductorTestCase):
@@ -16,12 +16,37 @@ index 84264bf1b0..eb3fac0f39 100644
 
          self.assertEqual(
 -            (0, 2, 4, 5, 6),
-+            [(0, 2, 4, 5, 6)],
++            [(0,), (2,), (4,), (5,), (6,)],
              _check_divisibility(
                  triton_utils.config_of(
                      [
+diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py
+index 4d7a85029e3..f3d45ea5520 100644
+--- a/test/inductor/test_triton_kernels.py
++++ b/test/inductor/test_triton_kernels.py
+@@ -1268,9 +1268,9 @@ def forward(self, x_1, output_1):
+         if dynamic:
+             # when half_n_elements passed to the Triton kernel is
+             # dynamic, equal_to_1 specializaiton can't be enforced
+-            self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
++            self.assertTrue(_triton_get_ast_equal_to_str([]) in sources[0])
+         else:
+-            self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0])
++            self.assertTrue(_triton_get_ast_equal_to_str([(3,)]) in sources[0])
+         self.assertEqual(compiled_out, eager_out)
+ 
+     @requires_gpu
+@@ -1299,7 +1299,7 @@ def forward(self, x_1, output_1):
+ 
+         # float 1.0 (both literal or symbolic)
+         # should not be added to equal_to_1
+-        self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
++        self.assertTrue(_triton_get_ast_equal_to_str([]) in sources[0])
+         self.assertEqual(compiled_out, eager_out)
+ 
+     @requires_gpu
 diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
-index ace56135fe..568cbde0a1 100644
+index ace56135fe1..7e925dd6e45 100644
 --- a/torch/_higher_order_ops/triton_kernel_wrap.py
 +++ b/torch/_higher_order_ops/triton_kernel_wrap.py
 @@ -238,7 +238,7 @@ def generate_ttir(
@@ -33,43 +58,41 @@ index ace56135fe..568cbde0a1 100644
      except ImportError:
          return kernel._get_config(*args)
 
-@@ -251,7 +251,6 @@ def generate_ttir(
+@@ -247,7 +247,8 @@ def generate_ttir(
+         name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor)
+     }
+ 
+-    # Build kernel signature -- doesn't include constexpr arguments.
++    # Build kernel signature; it should also include `constexpr` arguments but `kernel._key_of`
++    # doesn't work correctly with them. They will be added in `fixup_signature` function later.
      signature = {
          name: kernel._type_of(kernel._key_of(arg))
          for i, (name, arg) in enumerate(ordered_args.items())
--        if i not in kernel.constexprs
-     }
- 
+@@ -257,7 +258,18 @@ def generate_ttir(
      triton._C.libtriton.ir.load_dialects(context)
+     backend.load_dialects(context)
+ 
+-    src = ASTSource(kernel, signature, constants, specialization)
++    def fixup_signature(arg_names, signature, constants):
++        new_signature = {arg_name: None for arg_name in arg_names}
++        for arg_name in arg_names:
++            if arg_name in constants and arg_name not in signature:
++                # If it's not in the signature already, it's a constexpr
++                # argument that we need to add in
++                new_signature[arg_name] = "constexpr"
++            else:
++                new_signature[arg_name] = signature[arg_name]
++        return new_signature
++
++    src = ASTSource(kernel, fixup_signature(kernel.arg_names, signature, constants), constants, specialization)
+ 
+     # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle
+     # backward compatibility here.
 diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
-index 00031a56b8..b941e2aaa6 100644
+index 00031a56b8d..59086d41b40 100644
 --- a/torch/_inductor/codegen/triton.py
 +++ b/torch/_inductor/codegen/triton.py
-@@ -2980,6 +2980,7 @@ class TritonKernel(SIMDKernel):
-         code.splice(self.imports_for_benchmark_kernel())
- 
-         argdefs, _, signature, _ = self.args.python_argdefs()
-+        # breakpoint()
-         # maps actual expression to SizeArg if it is in sizevars replacements
-         for i, arg in enumerate(signature):
-             if isinstance(arg, SizeArg):
-@@ -3030,7 +3031,7 @@ class TritonKernel(SIMDKernel):
-         triton_meta = {
-             "signature": triton_meta_signature,
-             "device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
--            "constants": {},
-+            "constexprs": {},
-         }
- 
-         # Skip memory optimization for forward of the training loop where we expect
-@@ -3065,20 +3066,12 @@ class TritonKernel(SIMDKernel):
-             argdefs.append(f"{tree.prefix}numel")
-             # constexpr version causes issues, see
-             # https://github.com/pytorch/torchdynamo/pull/1362
--            # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
-+            # triton_meta["constexprs"][len(argdefs)] = V.graph.sizevars.size_hint(
-             # tree.numel
-             # )
+@@ -3071,14 +3071,6 @@ class TritonKernel(SIMDKernel):
              # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
          triton_meta["configs"] = [config_of(signature)]
 
@@ -84,175 +107,58 @@ index 00031a56b8..b941e2aaa6 100644
          self.triton_meta = triton_meta
 
          for tree in self.range_trees:
-@@ -3087,9 +3080,14 @@ class TritonKernel(SIMDKernel):
-                 continue
-             if tree.tensor_dim is None:
-                 continue
--            argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")
-+            const_name = f"{tree.prefix.upper()}BLOCK"
-+            triton_meta['signature'][const_name] = 'constexpr'
-+            triton_meta['constexprs'][const_name] = tree.numel
-+            argdefs.append(f"{const_name} : tl.constexpr")
- 
-         if self.cooperative_reduction:
-+            triton_meta['signature']['RSPLIT'] = 'constexpr'
-+            triton_meta['constexprs']['RSPLIT'] = tree.numel
-             argdefs.append("RSPLIT : tl.constexpr")
- 
-         self.codegen_body()
-diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
-index 8b8c29bbb1..3e5abaa824 100644
---- a/torch/_inductor/codegen/triton_utils.py
-+++ b/torch/_inductor/codegen/triton_utils.py
-@@ -157,13 +157,13 @@ def config_of(
-         raise NotImplementedError(f"unhandled {type(x)}: {x}")
- 
-     if config.triton.divisible_by_16:
--        divisible_by_16 = tuple(
-+        divisible_by_16 = [tuple(
-             i
-             for i, arg in zip(indices, args)
-             if is_aligned(arg, alignment=16, include_tensor=True)
--        )
-+        )]
-     else:
--        divisible_by_16 = ()
-+        divisible_by_16 = []
- 
-     equal_to_1 = tuple(
-         i
-@@ -172,5 +172,7 @@ def config_of(
-         and isinstance(arg.expr, (int, sympy.Integer))
-         and V.graph.sizevars.statically_known_equals(arg.expr, 1)  # type: ignore[arg-type]
-     )
-+    if equal_to_1 != tuple():
-+        equal_to_1 = [equal_to_1]
- 
-     return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
 diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
-index 2ab2b32635..5f08c3c0b7 100644
+index 2ab2b326354..42d76b8bf94 100644
 --- a/torch/_inductor/codegen/wrapper.py
 +++ b/torch/_inductor/codegen/wrapper.py
-@@ -1535,16 +1535,21 @@ class PythonWrapperCodegen(CodeGen):
- 
-         signature: List[KernelArgType] = []
-         constants: Dict[str, Any] = {}
-+        constexprs = {}
-         non_constant_indices = []
-         equal_to_1_args: List[str] = []
-         for idx, key in enumerate(kernel.arg_names):
-             if key not in kwargs:
-+                if idx in kernel.constexprs:
-+                    constexprs[key] = 'constexpr'
-                 continue
-             arg = kwargs[key]
-             if idx in kernel.constexprs:
-                 constants[key] = arg
-+                constexprs[key] = 'constexpr'
-             elif kwargs[key] is None:
-                 constants[key] = None
-+                constexprs[key] = 'constexpr'
-             else:
-                 non_constant_indices.append(idx)
-                 if isinstance(arg, ir.TMADescriptor):
-@@ -1596,9 +1601,8 @@ class PythonWrapperCodegen(CodeGen):
-             # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
-             # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
+@@ -1598,7 +1598,6 @@ class PythonWrapperCodegen(CodeGen):
              # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
--            "constants": {
-+            "constexprs": {
+             "constants": {
                  **constants,
 -                **dict.fromkeys(equal_to_1_args, 1),
              },
              "configs": [
                  config_of(
-@@ -1607,6 +1611,8 @@ class PythonWrapperCodegen(CodeGen):
-                 )
-             ],
-         }
-+        for constexpr_name in constexprs.keys():
-+            triton_meta['signature'][constexpr_name] = 'constexpr'
- 
-         if restore_value_args:
-             triton_meta["restore_value"] = tuple(restore_value_args)
 diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
-index 276c01f3f4..4e6e1ab9ce 100644
+index 276c01f3f42..5c633b7963b 100644
 --- a/torch/_inductor/runtime/hints.py
 +++ b/torch/_inductor/runtime/hints.py
-@@ -53,6 +53,7 @@ if _is_triton_available():
+@@ -48,8 +48,8 @@ if _is_triton_available():
+         ):
+             # Prepare the arguments for AttrsDescriptor
+             kwargs = {
+-                "tt.divisibility": divisible_by_16,
+-                "tt.equal_to": equal_to_1,
++                "tt.divisibility": tuple([(i,) for i in divisible_by_16]),
++                "tt.equal_to": tuple([(i,) for i in equal_to_1]),
              }
 
              # Instantiate AttrsDescriptor with the prepared arguments
-+            # breakpoint()
-             res = AttrsDescriptor.from_dict(
-                 {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
-             )
 diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
-index af8530e94d..a1935831e2 100644
+index af8530e94d0..1ec44de9806 100644
 --- a/torch/_inductor/runtime/triton_heuristics.py
 +++ b/torch/_inductor/runtime/triton_heuristics.py
-@@ -407,6 +407,7 @@ class CachingAutotuner(KernelInterface):
- 
-     def _precompile_config(self, cfg: Config, warm_cache_only: bool):
-         """Ahead of time compile a given autotuner config."""
-+        # print(f"self.triton_meta: {self.triton_meta}")
-         compile_meta = copy.deepcopy(self.triton_meta)
-         for k, v in cfg.kwargs.items():
-             if self.device_props.type == "hip":
-@@ -419,7 +420,7 @@ class CachingAutotuner(KernelInterface):
-                 if k == "kpack":
-                     compile_meta["kpack"] = v
-                     continue
--            compile_meta["constants"][k] = v
-+            compile_meta["constexprs"][k] = v
-         compile_meta["num_warps"] = cfg.num_warps
-         compile_meta["num_stages"] = cfg.num_stages
-         compile_meta["debug"] = self.inductor_meta.get(
-@@ -435,12 +436,13 @@ class CachingAutotuner(KernelInterface):
+@@ -435,11 +435,22 @@ class CachingAutotuner(KernelInterface):
          else:
              triton_helpers.set_driver_to_gpu()
 
--+        # print(compile_meta)
++        def fixup_signature(arg_names, signature, constants):
++            new_signature = {arg_name: None for arg_name in arg_names}
++            for arg_name in arg_names:
++                if arg_name in constants and arg_name not in signature:
++                    # If it's not in the signature already, it's a constexpr
++                    # argument that we need to add in
++                    new_signature[arg_name] = "constexpr"
++                else:
++                    new_signature[arg_name] = signature[arg_name]
++            return new_signature
++
          if ASTSource:
              compile_args = (
                  ASTSource(
                      self.fn,
-                     compile_meta["signature"],
--                    compile_meta["constants"],
-+                    compile_meta["constexprs"],
+-                    compile_meta["signature"],
++                    fixup_signature(self.fn.arg_names, compile_meta["signature"], compile_meta["constants"]),
+                     compile_meta["constants"],
                      compile_meta["configs"][0],
                  ),
-             )
-@@ -527,7 +529,7 @@ class CachingAutotuner(KernelInterface):
-         We also don't want to modify self.fn.
- 
-         We know that we removed something from the signature if:
--        1. It's in compile_meta["constants"]
-+        1. It's in compile_meta["constexprs"]
-         2. It isn't a constant we already know about
-         Note: The value of interest has already been added to compile_meta['constants'],
-         so we use self.fn.constexprs instead.
-@@ -538,7 +540,7 @@ class CachingAutotuner(KernelInterface):
-         }
-         none_args = {
-             k
--            for k, v in compile_meta["constants"].items()
-+            for k, v in compile_meta["constexprs"].items()
-             if v is None and k not in known_constants
-         }
-         none_args = none_args.difference(set(compile_meta["signature"].keys()))
-@@ -548,12 +550,14 @@ class CachingAutotuner(KernelInterface):
-             for i, arg in enumerate(self.fn.arg_names)
-             if i not in self.fn.constexprs and arg not in none_args
-         ]
-+        # print(f"call_args: {call_args}")
- 
-         def_args = [
-             name
-             for name in self.fn.arg_names
-             if name not in cfg.kwargs and name not in none_args
-         ]
-+        # print(f"def_args: {def_args}\n")
-         binary_shared = (
-             binary.shared if hasattr(binary, "shared") else binary.metadata.shared
-         )
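The updated test expectations above (e.g. [(0,), (2,), (4,), (5,), (6,)] in test_codegen_triton.py) and the hints.py change reflect that the newer Triton AttrsDescriptor takes per-parameter hint tuples rather than a flat tuple of argument indices. Below is a minimal sketch of that conversion, assuming the same tt.divisibility / tt.equal_to keys used in the patch; it is illustrative only, not the exact Inductor code, and the index values are made up.

# Convert flat index hints into the nested per-parameter form used above
# (mirroring the torch/_inductor/runtime/hints.py change in this patch).

def to_nested_hints(divisible_by_16, equal_to_1):
    # Each argument index becomes its own 1-tuple, e.g. (0, 2) -> ((0,), (2,)).
    return {
        "tt.divisibility": tuple((i,) for i in divisible_by_16),
        "tt.equal_to": tuple((i,) for i in equal_to_1),
    }


if __name__ == "__main__":
    # Hypothetical hints: args 0, 2, 4, 5, 6 divisible by 16; arg 3 equal to 1.
    print(to_nested_hints((0, 2, 4, 5, 6), (3,)))
    # {'tt.divisibility': ((0,), (2,), (4,), (5,), (6,)), 'tt.equal_to': ((3,),)}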
