
Commit f68de58

blaine-rister authored and pytorchmergebot committed
[Inductor-FX] Support symbol and dynamic scalar graph inputs and outputs (pytorch#163596)
# Problems

This PR fixes a few edge cases that the FX converter missed related to dynamic shapes.

1. Inductor graphs can sometimes take `sympy.Symbol` inputs. We have logic to convert these to FX placeholder nodes. However, this logic did not update the `self.expr_to_proxy` table mapping symbols to proxy nodes. (There was existing logic to do this for `ir.TensorBox` inputs, but not `sympy.Symbol`.) This caused sympy tracing to fail when these symbol inputs were used in other expressions.
2. We lacked codegen for `ShapeAsConstantBuffer`. This IR node is seen when the graph input or output is a scalar computed from dynamic shapes.

# Fixes

a. Update `self.expr_to_proxy` when generating placeholders for `sympy.Symbol` inputs. Change `SymbolBuffer.get_example` to convert the symbol to a `torch.SymInt`, so we can populate `meta["val"]` correctly and use the value in other computations.
b. Support `ShapeAsConstantBuffer` by tracing the sympy expression.
c. Move output generation inside the metadata hook, allowing us to populate `meta["val"]` for the nodes computing `ShapeAsConstantBuffer`.

# Test plan

Added several new CI tests:

1. `torch.cond` with dynamic shapes. This exposes both issues, as the predicate is a `ShapeAsConstantBuffer` and one of the subgraphs uses a symbol input, due to the closure. Also tests when the parent and subgraphs have different input shapes.
2. Output dynamic shape scalar. This tests `ShapeAsConstantBuffer` as an output.

Pull Request resolved: pytorch#163596
Approved by: https://github.com/angelayi, https://github.com/jansel
1 parent a8e9ed2 commit f68de58
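For context, here is a minimal sketch of the second scenario fixed here (a scalar output computed from a dynamic shape), mirroring the new test_dynamic_scalar_output test added below. The export call and Dim.DYNAMIC marker are standard torch.export API; in CI, the test harness (self.check in the diff) is what routes the exported program through the Inductor FX converter.

import torch
from torch.export import Dim, export


class M(torch.nn.Module):
    def forward(self, x):
        # The output is a scalar computed from a dynamic dimension; Inductor
        # represents such values with ShapeAsConstantBuffer.
        return x.shape[0] * 3


x = torch.randn(7)
# Mark dim 0 as dynamic so the output depends on a shape symbol rather than
# being baked in as the constant 21.
ep = export(M(), (x,), dynamic_shapes=({0: Dim.DYNAMIC},))
print(ep.graph)  # contains sym_size_int and mul nodes instead of a constant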

File tree: 3 files changed, +81 −14 lines

test/inductor/test_fxir_backend.py

Lines changed: 42 additions & 0 deletions
@@ -990,6 +990,48 @@ def forward(self, arg0_1, arg1_1, arg2_1):
                 return [buf1, buf2]""",  # noqa: B950
             )
 
+    @parametrize("length", (4, 8))
+    def test_cond_dynamic_shape_pred_scalar_closure(self, length: int):
+        """
+        Test cond using a predicate computed from dynamic shapes.
+        Also test a dynamic scalar computed outside the branches.
+        """
+
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                z = x.reshape(-1)
+                a = y.shape[0]
+
+                def true_fn(x):
+                    return x + a
+
+                def false_fn(x):
+                    return true_fn(x) / 2
+
+                return torch.cond(x.shape[0] > 5, true_fn, false_fn, (z,))
+
+        (x, y) = [
+            torch.randn(shape, device=self.device)
+            for shape in [(length // 2,) * 2, (length,)]
+        ]
+        dynamic_shapes = {
+            "x": {0: Dim.DYNAMIC},
+            "y": {0: Dim.DYNAMIC},
+        }
+        self.check(M(), (x, y), dynamic_shapes=dynamic_shapes)
+
+    def test_dynamic_scalar_output(self):
+        """
+        Test an output scalar from dynamic shapes.
+        """
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x.shape[0] * 3
+
+        x = torch.randn(7, device=self.device)
+        self.check(M(), (x,), dynamic_shapes=({0: Dim.DYNAMIC},))
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

torch/_inductor/codegen/wrapper_fxir.py

Lines changed: 38 additions & 13 deletions
@@ -23,7 +23,11 @@
 from torch._inductor.codecache import LambdaFuture, PyCodeCache
 from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 from torch._inductor.select_algorithm import extern_kernels  # noqa: F401
-from torch._inductor.utils import convert_shape_to_symint, sympy_product
+from torch._inductor.utils import (
+    convert_shape_to_symint,
+    convert_to_symint,
+    sympy_product,
+)
 from torch._inductor.virtualized import V
 from torch._library.triton import wrap_triton
 from torch.fx import GraphModule
@@ -89,8 +93,10 @@ class SymbolBuffer(CodegenSymbol):
     def get_name(self) -> str:
         return str(self.symbol)
 
-    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
-        return self.symbol
+    def get_example(self) -> Union[torch.Tensor, torch.SymInt]:
+        sym_int = convert_to_symint(self.symbol)
+        assert isinstance(sym_int, torch.SymInt)
+        return sym_int
 
 
 CodegenBuffer = Union[BufferLike, SymbolBuffer]
@@ -386,6 +392,13 @@ def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer:
         else:
             raise NotImplementedError(f"Unable to extract buffer from node: {node}")
 
+    def _generate_size_proxy(
+        self, node: torch.fx.Node, expr: sympy.Expr
+    ) -> torch.fx.Proxy:
+        proxy = torch.fx.Proxy(node, tracer=self.tracer)
+        self.expr_to_proxy[expr] = proxy
+        return proxy
+
     def _generate_graph_inputs(self) -> None:
         """
         Converts graph inputs to FX placeholders.
@@ -398,15 +411,22 @@ def _generate_graph_inputs(self) -> None:
                 continue
 
             # Introduce a new symbol for constant inputs.
+            is_constant = isinstance(ir_node, (int, float, sympy.Integer, sympy.Float))
             buffer = (
                 SymbolBuffer(sympy.Symbol(name, is_integer=True))
-                if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float))
+                if is_constant
                 else self._get_buffer(ir_node)
             )
             placeholder_node = self.gm.graph.placeholder(buffer.get_name())
-            placeholder_node.meta["val"] = buffer.get_example()
+            placeholder_node.meta["val"] = (
+                ir_node if is_constant else buffer.get_example()
+            )
             self._record_allocation(buffer, placeholder_node)
 
+            # Record symbol definitions for dynamic shapes.
+            if isinstance(ir_node, sympy.Symbol):
+                self._generate_size_proxy(placeholder_node, ir_node)
+
     def _generate_graph_input_shapes(self) -> None:
         """
         Generate nodes creating symints that are part of graph input
@@ -421,8 +441,7 @@ def _codegen_symbol(
         ) -> None:
             def codegen_proxy() -> torch.fx.Proxy:
                 size_node = self.gm.graph.call_function(target, (base_node, dim))
-                size_proxy = torch.fx.Proxy(size_node, tracer=self.tracer)
-                self.expr_to_proxy[sym_or_exp] = size_proxy
+                size_proxy = self._generate_size_proxy(size_node, sym_or_exp)
                 return size_proxy
 
             if isinstance(sym_or_exp, sympy.Symbol):
@@ -475,9 +494,7 @@ def codegen_proxy() -> torch.fx.Proxy:
                     undefined_symbol_expr
                 ]
 
-        for node in V.graph.module.graph.find_nodes(op="placeholder"):  # type: ignore[operator, union-attr]
-            name = node.name
-            ir_node = self.graph_inputs.get(name)
+        for ir_node in self.graph_inputs.values():
             if isinstance(ir_node, ir.TensorBox):
                 buffer = self._get_buffer(ir_node)
                 placeholder_node = self.buffer_to_node[buffer.get_name()]
@@ -504,6 +521,10 @@ def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]:
         Does nothing if no such transformations are present.
         """
 
+        if isinstance(node, ir.ShapeAsConstantBuffer):
+            # Generate FX nodes to compute the shape expression.
+            return self._sympy_interp(node.expr).node
+
         def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]:
             if isinstance(node, (ir.Buffer, WorkspaceArg)):
                 return node
@@ -539,7 +560,9 @@ def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]:
         buffer = generate_to_buffer(node)
         return self.buffer_to_node[buffer.get_name()] if buffer is not None else None
 
-    def _generate_output(self) -> None:
+    def _generate_outputs(
+        self,
+    ) -> Union[Optional[torch.fx.Node], list[Optional[torch.fx.Node]]]:
         """
         Generate FX IR for graph outputs.
         """
@@ -554,7 +577,7 @@ def _generate_output(self) -> None:
             else output_nodes
         )
 
-        self.gm.graph.output(output_value)
+        return output_value
 
     def _generate_subgm_getattrs(self) -> None:
         """
@@ -614,7 +637,9 @@ def generate(self) -> torch.fx.GraphModule:
             )
         )
 
-        self._generate_output()
+        output = self._generate_outputs()
+
+        self.gm.graph.output(output)
         self.gm.recompile()
         return self.gm

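Both the new _generate_size_proxy helper and the ShapeAsConstantBuffer branch above rely on the same mechanism: once a graph node is wrapped in a torch.fx.Proxy and registered in expr_to_proxy, ordinary arithmetic on that proxy appends call_function nodes to the wrapper graph, which is how a sympy expression over shape symbols becomes FX IR. A standalone sketch of that mechanism, independent of the Inductor converter and assuming torch.fx.proxy.GraphAppendingTracer (the FX tracer for appending to an existing graph):

import torch
import torch.fx as fx
from torch.fx.proxy import GraphAppendingTracer

graph = fx.Graph()
tracer = GraphAppendingTracer(graph)

# A placeholder standing in for a dynamic-shape symbol such as s0.
s0_node = graph.placeholder("s0")
s0 = fx.Proxy(s0_node, tracer=tracer)

# Arithmetic on the proxy appends call_function nodes to the graph;
# this is how an expression like s0 * 3 can be lowered to FX nodes.
result = s0 * 3
graph.output(result.node)

gm = fx.GraphModule(torch.nn.Module(), graph)
print(gm.graph)  # placeholder s0 -> call_function mul -> output
print(gm(7))     # 21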
torch/_inductor/ir.py

Lines changed: 1 addition & 1 deletion
@@ -4220,7 +4220,7 @@ def get_name(self) -> str:
         assert self.name, self
         return self.name
 
-    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
+    def get_example(self) -> Union[torch.Tensor, torch.SymInt]:
         if isinstance(self.layout, Layout):
             return self.layout.get_example()
         raise NotImplementedError(type(self.layout).__name__)
