|
16 | 16 | from torch._decomp import register_decomposition |
17 | 17 | from torch._dynamo.utils import counters |
18 | 18 | from torch._inductor import comms |
19 | | -from torch._inductor.virtualized import ops |
| 19 | +from torch._inductor.virtualized import ops # noqa: F401 |
20 | 20 | from torch._logging import trace_structured |
21 | 21 | from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype |
22 | 22 | from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq |
23 | 23 | from torch.utils._ordered_set import OrderedSet |
24 | 24 |
|
25 | | -from .. import config, ir, pattern_matcher |
| 25 | +from .. import config, ir, pattern_matcher # noqa: F401 |
26 | 26 | from ..codegen.common import custom_backend_passes |
27 | 27 | from ..comms import remove_fsdp2_unsharded_param_graph_input_usage |
28 | 28 | from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage |
@@ -802,95 +802,6 @@ def is_valid_mm_plus_mm(match: Match): |
802 | 802 | return True |
803 | 803 |
|
804 | 804 |
|
805 | | -def scatter_upon_const_tensor_extra_check(m): |
806 | | - if not config.optimize_scatter_upon_const_tensor: |
807 | | - return False |
808 | | - full_shape = m.kwargs["shape"] |
809 | | - selector = m.kwargs["selector"] |
810 | | - dim = m.kwargs["dim"] |
811 | | - if dim < 0: |
812 | | - dim += len(full_shape) |
813 | | - |
814 | | - selector_ft = selector.meta["val"] |
815 | | - assert selector_ft.dim() == len(full_shape) |
816 | | - |
817 | | - for idx, select_sz, full_sz in zip( |
818 | | - itertools.count(), selector_ft.shape, full_shape |
819 | | - ): |
820 | | - if idx == dim: |
821 | | - continue |
822 | | - |
823 | | - # TODO: the pattern can be updated to support the case that index tensor |
824 | | - # is shorter. But that will need a more complex condition expression |
825 | | - # especially for multi-dimensional tensors. |
826 | | - # Skip it for now. |
827 | | - if isinstance(full_sz, fx.Node): |
828 | | - full_sz = full_sz.meta["val"] |
829 | | - if select_sz < full_sz: |
830 | | - return False |
831 | | - |
832 | | - # Actually we can support small size larger than 1. It would be a bit |
833 | | - # tedius. E.g., we load all the index values (not many) and compare |
834 | | - # them with the position in tensor to decide what value to return. |
835 | | - return selector_ft.size(dim) == 1 |
836 | | - |
837 | | - |
838 | | -@register_lowering_pattern( |
839 | | - CallFunction( |
840 | | - aten.scatter.value, |
841 | | - CallFunction( |
842 | | - aten.full, |
843 | | - KeywordArg("shape"), |
844 | | - KeywordArg("background_val"), |
845 | | - dtype=KeywordArg("dtype"), |
846 | | - ), |
847 | | - KeywordArg("dim"), |
848 | | - KeywordArg("selector"), |
849 | | - KeywordArg("val"), # scalar value |
850 | | - ), |
851 | | - extra_check=scatter_upon_const_tensor_extra_check, |
852 | | -) |
853 | | -def scatter_upon_const_tensor( |
854 | | - match: Match, shape, background_val, dtype, dim, selector, val |
855 | | -): |
856 | | - """ |
857 | | - Match the pattern of full+scatter into a pointwise. |
858 | | -
|
859 | | - TODO: Right now the scatter value must be a scalar. But we could support it |
860 | | - when it is a tensor as well. |
861 | | - """ |
862 | | - from torch._inductor import metrics |
863 | | - |
864 | | - # Check if inputs are tensors instead of inductor IR nodes |
865 | | - if isinstance(selector, torch.Tensor): |
866 | | - # Return a fake tensor with the proper shape that this operator is intended to return |
867 | | - device = selector.device if hasattr(selector, "device") else torch.device("cpu") |
868 | | - return torch.empty(shape, dtype=dtype, device=device) |
869 | | - |
870 | | - # pyrefly: ignore [bad-assignment] |
871 | | - metrics.num_matches_for_scatter_upon_const_tensor += 1 |
872 | | - |
873 | | - selector_loader = selector.make_loader() |
874 | | - |
875 | | - def inner_fn(idx): |
876 | | - selector_idx = list(idx) |
877 | | - selector_idx[dim] = 0 |
878 | | - |
879 | | - selector = selector_loader(selector_idx) |
880 | | - return ops.where( |
881 | | - selector == ops.index_expr(idx[dim], torch.int64), |
882 | | - ops.constant(val, dtype), |
883 | | - ops.constant(background_val, dtype), |
884 | | - ) |
885 | | - |
886 | | - return ir.Pointwise.create( |
887 | | - device=selector.get_device(), |
888 | | - dtype=dtype, |
889 | | - inner_fn=inner_fn, |
890 | | - ranges=shape, |
891 | | - ) |
892 | | - |
893 | | - |
894 | 805 | @register_lowering_pattern( |
895 | 806 | CallFunction( |
896 | 807 | aten.add, |
|
0 commit comments