
Commit 92f41cc

blaine-rister authored and pytorchmergebot committed
[Inductor] Support precomputed size args in the FX backend. (pytorch#157758)
# Feature

If a Triton kernel has a complicated indexing expression, Inductor may decide to precompute it on the host and pass it to the kernel as an argument. This happens in situations like broadcasts with dynamic shapes. This PR adds support for this feature to Inductor's FX IR backend.

We generate FX IR for precomputed size args in 3 steps:

1. In `PythonWrapperCodegen`, this PR refactors the relevant code to use a `SymbolicCallArgLine` instead of raw Python strings. This stores a (symbol, expr) pair. (Prior to this PR, it was (str, expr), but changing this to a symbol makes it easier to do substitutions later on.)
2. In `WrapperFxCodegen`, keep a dict of {symbol: expr} arg defs which gets updated whenever we see a `SymbolicCallArgLine`.
3. When the FX backend sees a `KernelCallLine`, it uses this dict to replace symbolic call args with their definitions.

In the longer run, it might be desirable to emit FX nodes defining these symbolic call args. That way, we could reuse the size computation when the same kernel is called multiple times. However, I wasn't sure if there was an existing way to generate FX nodes from a sympy expression, and implementing that seemed like overkill for the present purposes.

# Test plan

Added a new CI test exercising this feature.

Pull Request resolved: pytorch#157758
Approved by: https://github.com/jansel
1 parent 95bc3da · commit 92f41cc
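Conceptually, steps 2 and 3 of the description above amount to keeping a symbol-to-expression table and applying it to every kernel call argument before it is turned into a SymInt. A minimal sketch of that idea in plain sympy; the names `arg_defs` and `resolve_call_arg` and the example definition `12 * s1` are illustrative only, not taken from the PR:

```python
import sympy

# Step 2: a table of precomputed size args. Each entry maps the kernel's
# call-arg symbol to the host-side expression that defines it
# (the values below are made up for illustration).
s1 = sympy.Symbol("s1", integer=True, positive=True)
ks0 = sympy.Symbol("ks0", integer=True, positive=True)
arg_defs = {ks0: 12 * s1}

# Step 3: when a kernel call is emitted, symbolic call args are replaced by
# their definitions, so only the graph's "real" symbols remain.
def resolve_call_arg(expr):
    if not isinstance(expr, sympy.Expr):
        return expr  # plain ints pass through unchanged
    return expr.subs(arg_defs)

print(resolve_call_arg(ks0))      # 12*s1
print(resolve_call_arg(ks0 + 1))  # 12*s1 + 1
print(resolve_call_arg(7))        # 7
```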

File tree: 3 files changed (+39, -13 lines)

test/inductor/test_fxir_backend.py

Lines changed: 16 additions & 0 deletions
@@ -393,6 +393,22 @@ def get_input():
         ]
         self.assertEqual(placeholder.meta["val"], symbol)
 
+    def test_dynamic_shapes_precomputed_size(self):
+        """
+        Test dynamic shapes where a kernel's size arg is precomputed.
+        """
+        func = torch.add
+        args = [
+            torch.randn(shape, device=self.device) for shape in [(7, 12, 9), (7, 1, 1)]
+        ]
+        (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
+
+        # Check for the precomputed size arg.
+        (triton_node,) = gm.graph.find_nodes(
+            op="call_function", target=triton_kernel_wrapper_mutation
+        )
+        self.assertIn("ks0", triton_node.kwargs["kwargs"])
+
     @config.patch({"trace.enabled": True})
     @unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
     def test_debug(self, mock_output_code):
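For a local run, the new test case should be selectable with the usual test filters, e.g. `pytest test/inductor/test_fxir_backend.py -k test_dynamic_shapes_precomputed_size` (or the equivalent `-k` filter when invoking the test file directly with Python); this assumes a local PyTorch build where Inductor's Triton backend is available.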

torch/_inductor/codegen/wrapper.py

Lines changed: 7 additions & 5 deletions
@@ -322,7 +322,7 @@ def traverse(cur_kernel):
 
 @dataclasses.dataclass
 class SymbolicCallArg:
-    inner: str
+    inner: sympy.Symbol
     # the original symbolic expression represented by inner
     inner_expr: sympy.Expr
 
@@ -1726,7 +1726,8 @@ def ensure_size_computed(self, sym: sympy.Symbol):
             return
         self.computed_sizes.add(sym)
         expr = V.graph.sizevars.inv_precomputed_replacements[sym]
-        self.writeline(f"{sym} = {pexpr(expr)}")
+        arg = SymbolicCallArg(sym, expr)
+        self.writeline(SymbolicCallArgLine(self, arg, V.graph))
 
     def finalize_prefix(self):
         pass
@@ -2257,9 +2258,10 @@ def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr:
         return name, triton_meta, extra_launcher_call_args
 
     def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None):
-        expr = f"{kernel_name}_{tree.prefix}numel"
+        sym_name = f"{kernel_name}_{tree.prefix}numel"
         if suffix is not None:
-            expr += f"_{suffix}"
+            sym_name += f"_{suffix}"
+        sym = sympy.Symbol(sym_name, is_integer=True, is_positive=True)
 
         # We can get symbolic expressions here, like s0*64
         # It is fine to have them here, but we need to handle them correctly as their own type
@@ -2268,7 +2270,7 @@ def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = No
         # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for
         # constant now, need type info. I agree, this needs type info, and while this is not true type info
         # it suffices as a type hint for the purposes of producing the correct code for this type.
-        arg = SymbolicCallArg(expr, tree.numel)
+        arg = SymbolicCallArg(sym, tree.numel)
         self.writeline(SymbolicCallArgLine(self, arg, V.graph))
 
         return arg
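As the PR description notes, storing a `sympy.Symbol` in `SymbolicCallArg.inner` (first hunk above) rather than the old `str` is what makes later substitution straightforward: downstream code can substitute the symbol directly instead of re-deriving it from a name. A small, self-contained illustration in plain sympy; the symbol names and the `s0 * 64` definition are made up for the example:

```python
import sympy

# A numel-style call arg and a stand-in for the expression it represents.
xnumel = sympy.Symbol("triton_poi_fused_add_0_xnumel", integer=True, positive=True)
s0 = sympy.Symbol("s0", integer=True, positive=True)

# With the Symbol in hand, replacing the call arg by its definition is a
# direct substitution on any expression that mentions it.
expr = xnumel + 1
print(expr.subs({xnumel: s0 * 64}))  # prints: 64*s0 + 1
```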

torch/_inductor/codegen/wrapper_fxir.py

Lines changed: 16 additions & 8 deletions
@@ -17,7 +17,7 @@
 from torch._inductor.codecache import PyCodeCache
 from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 from torch._inductor.select_algorithm import extern_kernels  # noqa: F401
-from torch._inductor.utils import sympy_product
+from torch._inductor.utils import sympy_product, sympy_subs
 from torch._inductor.virtualized import V
 from torch._library.triton import wrap_triton
 from torch.fx import GraphModule
@@ -155,6 +155,9 @@ def __post_init__(self) -> None:
             Optional[str], torch.fx.Node
         ] = {}  # Symbol table for codegen.
         self.kernels: dict[str, TritonKernel] = {}  # Table to store Triton kernels.
+        self.symbolic_arg_defs: dict[
+            sympy.Symbol, sympy.Expr
+        ] = {}  # Call arg definitions.
         self._unique_symbol_ids: Counter[str] = Counter()
 
     def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner:
@@ -576,12 +579,15 @@ def replace_floor_div(expr: sympy.Expr) -> sympy.Expr:
            else:
                return sympy.floor(expr)
 
-        def expr_to_symint(expr: Union[int, sympy.Expr]) -> Union[int, sympy.Expr]:
-            return (
-                convert_to_symint(expr.replace(sympy.floor, replace_floor_div))
-                if isinstance(expr, sympy.Expr)
-                else expr
-            )
+        def expr_to_symint(
+            expr: Union[int, torch.fx.Node, sympy.Expr],
+        ) -> Union[int, torch.fx.Node, sympy.Expr]:
+            if not isinstance(expr, sympy.Expr):
+                return expr
+
+            expr = expr.replace(sympy.floor, replace_floor_div)
+            expr = sympy_subs(expr, self.symbolic_arg_defs)
+            return convert_to_symint(expr)
 
         # Convert sympy expressions to symints.
         # Use FloorDiv over sympy.floor, so we can get nicer Python code from FX.
@@ -691,4 +697,6 @@ def _generate_kernel_definition(self, line: WrapperLine) -> None:
 
     def _generate_symbolic_call_arg(self, line: WrapperLine) -> None:
         assert isinstance(line, SymbolicCallArgLine)
-        # No need for an FX node, as we will pass the arg to kernels via a SymInt.
+        # Store the arg: expr mapping for later use.
+        arg = line.arg
+        self.symbolic_arg_defs[arg.inner] = arg.inner_expr
