1 change: 1 addition & 0 deletions test/inductor/test_metrics.py
@@ -76,6 +76,7 @@ def test_parse_reduction_hint(self):
         )
 
     @config.patch("fx_graph_remote_cache", False)
+    @config.patch("partitioned_scatter_enabled", False)
     def test_atomic_add(self):
         @torch.compile
         def f(lhs, index, rhs):
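The new flag is patched off here because test_atomic_add checks the metrics of the generated atomic-add kernel, which the partitioned-scatter pass exists to rewrite. The test body is truncated in this diff; the sketch below is a hypothetical reconstruction of the kind of function such a test compiles, not the actual test code:

import torch

@torch.compile
def f(lhs, index, rhs):
    # accumulate=True sums values at duplicate indices; on GPU, Inductor
    # lowers this to a kernel that uses atomic adds.
    return lhs.index_put_((index,), rhs, accumulate=True)

lhs = torch.zeros(1024, device="cuda")
index = torch.randint(0, 1024, (8192,), device="cuda")
rhs = torch.ones(8192, device="cuda")
f(lhs, index, rhs)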
3 changes: 3 additions & 0 deletions torch/_inductor/config.py
@@ -819,6 +819,9 @@ def decide_worker_start_method() -> str:

 _micro_pipeline_tp: bool = False
 
+# Enable/disable the partitioned scatter optimization for atomic-add kernels.
+# This can improve kernel performance at the cost of additional memory usage.
+partitioned_scatter_enabled = os.environ.get("TORCHINDUCTOR_PARTITIONED_SCATTER_ENABLED", "1") == "1"
 
 class _collective:
     auto_select: bool = False
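Since the flag defaults to on, the old behavior can be restored either through the environment variable read above or at runtime through the Inductor config module. A minimal sketch (the script name is a placeholder):

import torch._inductor.config as inductor_config

# Runtime toggle, equivalent to flipping the env-var default read above.
inductor_config.partitioned_scatter_enabled = False

# Shell toggle, set before the process starts:
#   TORCHINDUCTOR_PARTITIONED_SCATTER_ENABLED=0 python my_script.py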
5 changes: 5 additions & 0 deletions torch/_inductor/fx_passes/post_grad.py
@@ -63,6 +63,7 @@
 from .pre_grad import is_same_dict, save_inductor_dict
+from .reduced_atomic_contention import partitioned_scatter_optimization_pass
 from .reinplace import reinplace_inplaceable_ops
 from .split_cat import POST_GRAD_PATTERNS
 
 
 _T = TypeVar("_T")
@@ -140,6 +141,10 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
             GraphTransformObserver(gm, f"pass_pattern_{i}").apply_graph_pass(
                 patterns.apply
             )
+        if config.partitioned_scatter_enabled:
+            GraphTransformObserver(gm, "partitioned_scatter_optimization").apply_graph_pass(
+                partitioned_scatter_optimization_pass
+            )
         for pass_name in config.post_grad_fusion_options:
             # skip all patterns for group batch fusions or quantization patterns
             if pass_name in POST_GRAD_FUSIONS or pass_name in OPTIMUS_EXCLUDE_POST_GRAD:
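The pass itself is defined in the new reduced_atomic_contention module, which is not part of this diff. For orientation, a pass handed to GraphTransformObserver.apply_graph_pass receives the module's FX graph; the skeleton below shows the expected shape of such a pass, with placeholder matching logic rather than the PR's actual rewrite:

import torch

def partitioned_scatter_optimization_pass(graph: torch.fx.Graph) -> None:
    # Hypothetical skeleton: find scatter-style nodes whose lowering would
    # emit contended atomic adds and rewrite them into a partitioned
    # scatter-plus-reduction form. The match below is illustrative only.
    for node in graph.nodes:
        if node.op == "call_function" and node.target is torch.ops.aten.index_put.default:
            ...  # rewrite node into the partitioned form
    graph.lint()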