|
56 | 56 | sympy_product, |
57 | 57 | sympy_subs, |
58 | 58 | triton_type, |
| 59 | + triton_version_uses_attrs_dict, |
59 | 60 | upcast_compute_type, |
60 | 61 | ) |
61 | 62 | from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V |
62 | 63 | from ..wrapper_benchmark import get_kernel_category_by_source_code |
63 | 64 | from .block_analysis import BlockPatternMatcher |
64 | 65 | from .common import ( |
65 | 66 | BackendFeature, |
| 67 | + ConstexprArg, |
66 | 68 | CSE, |
67 | 69 | CSEVariable, |
68 | 70 | DeferredLine, |
|
85 | 87 | ) |
86 | 88 | from .triton_utils import ( |
87 | 89 | config_of, |
| 90 | + non_constexpr_signature, |
88 | 91 | should_unwrap_unspec_arg, |
89 | | - signature_of, |
90 | 92 | signature_to_meta, |
91 | 93 | ) |
92 | 94 |
|
@@ -3357,6 +3359,35 @@ def codegen_kernel(self, name=None): |
3357 | 3359 |
|
3358 | 3360 | mutated_args = sorted(mutated_args) |
3359 | 3361 |
|
| 3362 | + for tree in self.active_range_trees(): |
| 3363 | + sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) |
| 3364 | + signature.append(sizearg) |
| 3365 | + argdefs.append(sizearg.name) |
| 3366 | + # constexpr version causes issues, see |
| 3367 | + # https://github.com/pytorch/torchdynamo/pull/1362 |
| 3368 | + # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint( |
| 3369 | + # tree.numel |
| 3370 | + # ) |
| 3371 | + # argdefs.append(f"{tree.prefix}numel: tl.constexpr") |
| 3372 | + |
| 3373 | + def add_constexpr_arg(arg_name): |
| 3374 | + # newer versions of Triton (but not older ones) require constexprs to be included in the signature
| 3375 | + if triton_version_uses_attrs_dict(): |
| 3376 | + signature.append(ConstexprArg(arg_name)) |
| 3377 | + argdefs.append(f"{arg_name} : tl.constexpr") |
| 3378 | + |
| 3379 | + for tree in self.range_trees: |
| 3380 | + if tree.is_reduction and self.persistent_reduction: |
| 3381 | + # Rn_BLOCK for persistent_reduction is defined in codegen_static_numels |
| 3382 | + continue |
| 3383 | + if tree.tensor_dim is None: |
| 3384 | + continue |
| 3385 | + |
| 3386 | + add_constexpr_arg(f"{tree.prefix.upper()}BLOCK") |
| 3387 | + |
| 3388 | + if self.cooperative_reduction: |
| 3389 | + add_constexpr_arg("RSPLIT") |
| 3390 | + |
3360 | 3391 | triton_meta_signature = signature_to_meta( |
3361 | 3392 | signature, size_dtype=self.index_dtype, argdefs=argdefs |
3362 | 3393 | ) |
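
For context on the hunk above: on Triton versions where `triton_version_uses_attrs_dict()` returns True, constexpr parameters such as `XBLOCK` and `RSPLIT` must appear in the signature handed to the compiler, not just on the Python-level `def` line. A minimal sketch of the moving parts, assuming `ConstexprArg` is a simple name-carrying dataclass (the real definition lives in `.common` and may differ):

```python
from dataclasses import dataclass


def triton_version_uses_attrs_dict() -> bool:
    # Stub for illustration only; a plausible version of the real gate is
    # sketched after the next hunk.
    return True


@dataclass(frozen=True)
class ConstexprArg:
    # Assumed shape: a marker carrying only the parameter name, so later
    # passes can tell constexpr entries apart from tensor and size args.
    name: str


def add_constexpr_arg(signature, argdefs, arg_name):
    # Mirrors the helper in the hunk above: every Triton version gets the
    # def-line entry; newer versions also record it in the signature.
    if triton_version_uses_attrs_dict():
        signature.append(ConstexprArg(arg_name))
    argdefs.append(f"{arg_name} : tl.constexpr")
```

With `XBLOCK` routed through this helper, a generated kernel's `def` line still ends with `..., xnumel, XBLOCK : tl.constexpr` on every Triton version; only the signature metadata differs.
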
@@ -3390,42 +3421,19 @@ def codegen_kernel(self, name=None): |
3390 | 3421 | num_gb = self.estimate_kernel_num_bytes() / 1e9 |
3391 | 3422 | inductor_meta["kernel_num_gb"] = num_gb |
3392 | 3423 |
|
3393 | | - for tree in self.active_range_trees(): |
3394 | | - sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) |
3395 | | - signature.append(sizearg) |
3396 | | - triton_meta_signature[sizearg.name] = signature_of( |
3397 | | - sizearg, size_dtype=self.index_dtype |
3398 | | - ) |
3399 | | - argdefs.append(f"{tree.prefix}numel") |
3400 | | - # constexpr version causes issues, see |
3401 | | - # https://github.com/pytorch/torchdynamo/pull/1362 |
3402 | | - # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint( |
3403 | | - # tree.numel |
3404 | | - # ) |
3405 | | - # argdefs.append(f"{tree.prefix}numel: tl.constexpr") |
3406 | 3424 | triton_meta["configs"] = [config_of(signature)] |
3407 | 3425 |
|
3408 | | - # Triton compiler includes equal_to_1 args into constants even |
3409 | | - # when they are not constexpr. otherwise there may be a segfault |
3410 | | - # during launching the Inductor-compiled Triton kernel. |
3411 | | - # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 |
3412 | | - # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 |
3413 | | - for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] |
3414 | | - triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index] |
| 3426 | + if not triton_version_uses_attrs_dict(): |
| 3427 | + # Triton compiler includes equal_to_1 args into constants even |
| 3428 | + # when they are not constexpr. otherwise there may be a segfault |
| 3429 | + # during launching the Inductor-compiled Triton kernel. |
| 3430 | + # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 |
| 3431 | + # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 |
| 3432 | + for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] |
| 3433 | + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index] |
3415 | 3434 |
|
3416 | 3435 | self.triton_meta = triton_meta |
3417 | 3436 |
|
3418 | | - for tree in self.range_trees: |
3419 | | - if tree.is_reduction and self.persistent_reduction: |
3420 | | - # Rn_BLOCK for persistent_reduction is defined in codegen_static_numels |
3421 | | - continue |
3422 | | - if tree.tensor_dim is None: |
3423 | | - continue |
3424 | | - argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr") |
3425 | | - |
3426 | | - if self.cooperative_reduction: |
3427 | | - argdefs.append("RSPLIT : tl.constexpr") |
3428 | | - |
3429 | 3437 | self.codegen_body() |
3430 | 3438 |
|
3431 | 3439 | for helper in self.helper_functions: |
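
The `equal_to_1` workaround above is now applied only on older Triton: versions that report attrs as a plain dict consume specialization hints differently, so stuffing 1-valued args into `triton_meta["constants"]` is skipped there. A plausible sketch of the version gate, assuming it keys off the release that replaced `AttrsDescriptor` with a plain dict (the real helper, imported in the first hunk, may be implemented differently):

```python
import functools

import triton


@functools.lru_cache(None)
def triton_version_uses_attrs_dict() -> bool:
    # Assumed gate: Triton moved to a plain attrs dict around the 3.2
    # release, so branch on the installed version.
    major, minor = (int(p) for p in triton.__version__.split(".")[:2])
    return (major, minor) >= (3, 2)
```
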
@@ -3457,7 +3465,9 @@ def codegen_kernel(self, name=None): |
3457 | 3465 | else: |
3458 | 3466 | tile_hint = "" |
3459 | 3467 | if len(size_hints) == 2: |
3460 | | - if len(signature) == 4: # input, output and 2 args |
| 3468 | + if ( |
| 3469 | + len(non_constexpr_signature(signature)) == 4 |
| 3470 | + ): # input, output and 2 args |
3461 | 3471 | tile_hint = "tile_hint=TileHint.SQUARE," |
3462 | 3472 | else: |
3463 | 3473 | tile_hint = "tile_hint=TileHint.DEFAULT," |
|
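
The tile-hint tweak in the last hunk keeps the old arity heuristic intact: now that constexprs can live in `signature`, a 2D kernel is treated as SQUARE only when it has exactly four non-constexpr args (input, output, and two sizes). A plausible sketch of the filter, reusing the `ConstexprArg` sketch above (the real helper lives in `.triton_utils` and may differ):

```python
def non_constexpr_signature(signature):
    # Drop constexpr markers so arg-count heuristics see the same
    # count regardless of Triton version.
    return [arg for arg in signature if not isinstance(arg, ConstexprArg)]
```
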