
Commit 762273e

vinithakv authored and pytorchmergebot committed
Move pointwise_scatter optimization to joint_graph stage from post_grad (pytorch#165463)
Fixes pytorch#129449

Pull Request resolved: pytorch#165463
Approved by: https://github.com/eellison
1 parent 6edf2aa commit 762273e
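
For context, a minimal standalone sketch (not part of the commit; the shapes, dtype, and variable names below are invented for illustration) of the aten.full + aten.scatter pattern this optimization targets and the pointwise torch.where form it is rewritten into:

import torch

# Hypothetical inputs: a scalar scattered into a freshly created constant tensor.
shape = [8, 16]                                    # shape of the full tensor
dim = 1                                            # scatter dimension
selector = torch.randint(0, shape[dim], (8, 1))    # index tensor, size 1 along `dim`
background_val, val = 0.0, 1.0

# The matched pattern: materialize a constant tensor, then scatter a scalar into it.
scattered = torch.full(shape, background_val).scatter(dim, selector, val)

# The pointwise equivalent (compare repl_fn in joint_graph.py below): a mask built
# from broadcasted indices selects between the scatter value and the background.
indices = torch.arange(shape[dim], dtype=torch.int64).view(1, shape[dim])
pointwise = torch.where(selector.expand(shape) == indices, val, background_val)

assert torch.equal(scattered, pointwise)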

File tree

2 files changed: +91 -91 lines changed


torch/_inductor/fx_passes/joint_graph.py

Lines changed: 89 additions & 0 deletions
@@ -957,3 +957,92 @@ def repl(inp, other):
         pass_dict=pass_patterns[1],
         extra_check=_other_is_broadcasted_in_dim,
     )(div_softmax_pattern)
+
+
+def scatter_upon_const_tensor_extra_check(m):
+    if not config.optimize_scatter_upon_const_tensor:
+        return False
+    full_shape = m.kwargs["shape"]
+    selector = m.kwargs["selector"]
+    dim = m.kwargs["dim"]
+    if dim < 0:
+        dim += len(full_shape)
+
+    selector_ft = selector.meta["val"]
+    assert selector_ft.dim() == len(full_shape)
+
+    for idx, select_sz, full_sz in zip(
+        itertools.count(), selector_ft.shape, full_shape
+    ):
+        if idx == dim:
+            continue
+
+        # TODO: the pattern can be updated to support the case that index tensor
+        # is shorter. But that will need a more complex condition expression
+        # especially for multi-dimensional tensors.
+        # Skip it for now.
+        if isinstance(full_sz, torch.fx.Node):
+            full_sz = full_sz.meta["val"]
+        if select_sz < full_sz:
+            return False
+
+    # Actually we can support small size larger than 1. It would be a bit
+    # tedious. E.g., we load all the index values (not many) and compare
+    # them with the position in tensor to decide what value to return.
+    return selector_ft.size(dim) == 1
+
+
+@register_graph_pattern(
+    CallFunction(
+        aten.scatter.value,
+        CallFunction(
+            aten.full,
+            KeywordArg("shape"),
+            KeywordArg("background_val"),
+            dtype=KeywordArg("dtype"),
+        ),
+        KeywordArg("dim"),
+        KeywordArg("selector"),
+        KeywordArg("val"),  # scalar value
+    ),
+    # pyrefly: ignore [bad-argument-type]
+    pass_dict=patterns,
+    extra_check=scatter_upon_const_tensor_extra_check,
+)
+def scatter_upon_const_tensor(
+    match: Match, shape, background_val, dtype, dim, selector, val
+):
+    """
+    Match the pattern of full+scatter into a pointwise operation in joint graph.
+
+    TODO: Right now the scatter value must be a scalar. But we could support it
+    when it is a tensor as well.
+    """
+    from torch._inductor import metrics
+
+    # pyrefly: ignore # bad-assignment
+    metrics.num_matches_for_scatter_upon_const_tensor += 1
+
+    # Create a replacement that uses torch.where for the pointwise operation
+    def repl_fn(shape, background_val, dim, selector, val):
+        # Create a tensor of indices for the scatter dimension
+        length = shape[dim]
+        indices = torch.arange(length, device=selector.device, dtype=torch.int64)
+
+        # Reshape indices to have size 'length' at dim, then broadcast
+        view_shape = [1] * len(shape)
+        view_shape[dim] = length
+        indices_view = indices.view(*view_shape)
+
+        # Broadcast selector to match full tensor shape
+        selector_expanded = selector.expand(shape)
+
+        # Create a mask for where to scatter
+        mask = selector_expanded == indices_view
+
+        # Use torch.where to implement the scatter pointwise operation
+        return torch.where(mask, val, background_val)
+
+    # replace the scatter operation with pointwise equivalent
+    # pyrefly: ignore [bad-argument-type]
+    match.replace_by_example(repl_fn, [shape, background_val, dim, selector, val])

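As a rough usage sketch (an illustration, not a test from the commit): the new joint-graph pattern can be exercised through torch.compile. It is gated by the inductor config flag read in the extra_check above and observed through the metric counter it increments; how that counter is reset between runs is not shown in this diff.

import torch
from torch._inductor import config, metrics

# Required by scatter_upon_const_tensor_extra_check; the flag's default is not shown here.
config.optimize_scatter_upon_const_tensor = True

def full_then_scatter(selector):
    # aten.full followed by aten.scatter.value with a scalar value, i.e. the matched pattern
    return torch.full([8, 16], 0.0).scatter(1, selector, 1.0)

selector = torch.randint(0, 16, (8, 1))
eager = full_then_scatter(selector)
optimized = torch.compile(full_then_scatter)(selector)

assert torch.equal(eager, optimized)
# If the pattern fired during the joint-graph pass, this counter should have increased.
print(metrics.num_matches_for_scatter_upon_const_tensor)
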
torch/_inductor/fx_passes/post_grad.py

Lines changed: 2 additions & 91 deletions
@@ -16,13 +16,13 @@
 from torch._decomp import register_decomposition
 from torch._dynamo.utils import counters
 from torch._inductor import comms
-from torch._inductor.virtualized import ops
+from torch._inductor.virtualized import ops  # noqa: F401
 from torch._logging import trace_structured
 from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
 from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
 from torch.utils._ordered_set import OrderedSet

-from .. import config, ir, pattern_matcher
+from .. import config, ir, pattern_matcher  # noqa: F401
 from ..codegen.common import custom_backend_passes
 from ..comms import remove_fsdp2_unsharded_param_graph_input_usage
 from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
@@ -802,95 +802,6 @@ def is_valid_mm_plus_mm(match: Match):
     return True


-def scatter_upon_const_tensor_extra_check(m):
-    if not config.optimize_scatter_upon_const_tensor:
-        return False
-    full_shape = m.kwargs["shape"]
-    selector = m.kwargs["selector"]
-    dim = m.kwargs["dim"]
-    if dim < 0:
-        dim += len(full_shape)
-
-    selector_ft = selector.meta["val"]
-    assert selector_ft.dim() == len(full_shape)
-
-    for idx, select_sz, full_sz in zip(
-        itertools.count(), selector_ft.shape, full_shape
-    ):
-        if idx == dim:
-            continue
-
-        # TODO: the pattern can be updated to support the case that index tensor
-        # is shorter. But that will need a more complex condition expression
-        # especially for multi-dimensional tensors.
-        # Skip it for now.
-        if isinstance(full_sz, fx.Node):
-            full_sz = full_sz.meta["val"]
-        if select_sz < full_sz:
-            return False
-
-    # Actually we can support small size larger than 1. It would be a bit
-    # tedius. E.g., we load all the index values (not many) and compare
-    # them with the position in tensor to decide what value to return.
-    return selector_ft.size(dim) == 1
-
-
-@register_lowering_pattern(
-    CallFunction(
-        aten.scatter.value,
-        CallFunction(
-            aten.full,
-            KeywordArg("shape"),
-            KeywordArg("background_val"),
-            dtype=KeywordArg("dtype"),
-        ),
-        KeywordArg("dim"),
-        KeywordArg("selector"),
-        KeywordArg("val"),  # scalar value
-    ),
-    extra_check=scatter_upon_const_tensor_extra_check,
-)
-def scatter_upon_const_tensor(
-    match: Match, shape, background_val, dtype, dim, selector, val
-):
-    """
-    Match the pattern of full+scatter into a pointwise.
-
-    TODO: Right now the scatter value must be a scalar. But we could support it
-    when it is a tensor as well.
-    """
-    from torch._inductor import metrics
-
-    # Check if inputs are tensors instead of inductor IR nodes
-    if isinstance(selector, torch.Tensor):
-        # Return a fake tensor with the proper shape that this operator is intended to return
-        device = selector.device if hasattr(selector, "device") else torch.device("cpu")
-        return torch.empty(shape, dtype=dtype, device=device)
-
-    # pyrefly: ignore [bad-assignment]
-    metrics.num_matches_for_scatter_upon_const_tensor += 1
-
-    selector_loader = selector.make_loader()
-
-    def inner_fn(idx):
-        selector_idx = list(idx)
-        selector_idx[dim] = 0
-
-        selector = selector_loader(selector_idx)
-        return ops.where(
-            selector == ops.index_expr(idx[dim], torch.int64),
-            ops.constant(val, dtype),
-            ops.constant(background_val, dtype),
-        )
-
-    return ir.Pointwise.create(
-        device=selector.get_device(),
-        dtype=dtype,
-        inner_fn=inner_fn,
-        ranges=shape,
-    )
-
-
 @register_lowering_pattern(
     CallFunction(
         aten.add,

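For reference, the shape conditions that scatter_upon_const_tensor_extra_check enforces (the same in the removed post_grad version and the new joint_graph version) can be mirrored on plain Python shapes by this small hypothetical helper, which is not part of the commit:

def selector_shape_qualifies(full_shape, selector_shape, dim):
    # Mirror of the extra_check: normalize a negative dim, require equal rank,
    # require every non-scatter dim of the selector to be at least as large as
    # the full tensor, and require the scatter dim of the selector to have size 1.
    if dim < 0:
        dim += len(full_shape)
    assert len(selector_shape) == len(full_shape)
    for idx, (select_sz, full_sz) in enumerate(zip(selector_shape, full_shape)):
        if idx == dim:
            continue
        if select_sz < full_sz:
            return False
    return selector_shape[dim] == 1

assert selector_shape_qualifies([8, 16], [8, 1], dim=1)        # one index per row: matches
assert not selector_shape_qualifies([8, 16], [4, 1], dim=1)    # selector shorter in dim 0: skipped
assert not selector_shape_qualifies([8, 16], [8, 2], dim=1)    # more than one index per row: skipped
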
0 commit comments