
Commit 35dae27

oulgen authored and pytorchmergebot committed
[pallas backend] support reductions (pytorch#167953)
Pull Request resolved: pytorch#167953
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#167947, pytorch#167951
1 parent 9ff1922 commit 35dae27

File tree: 2 files changed (+144, -14 lines)


test/inductor/test_pallas.py

Lines changed: 55 additions & 0 deletions
@@ -690,6 +690,61 @@ def fn(a, b):
         expected = fn(a, b)
         self.assertEqual(result, expected)
 
+    def test_sum_reduction(self):
+        """Test sum reduction."""
+
+        def fn(x):
+            return x.sum()
+
+        compiled = self._compile(fn)
+
+        x = torch.randn(16, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    def test_max_reduction(self):
+        """Test max reduction."""
+
+        def fn(x):
+            return x.max()
+
+        compiled = self._compile(fn)
+
+        x = torch.randn(16, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    def test_min_reduction(self):
+        """Test min reduction."""
+
+        def fn(x):
+            return x.min()
+
+        compiled = self._compile(fn)
+
+        x = torch.randn(16, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    def test_prod_reduction(self):
+        """Test prod reduction."""
+        if self.DEVICE == "cuda":
+            self.skipTest("prod reduction not supported in Pallas GPU (Triton) backend")
+
+        def fn(x):
+            # Use smaller values to avoid overflow
+            return (x * 0.1).prod()
+
+        compiled = self._compile(fn)
+
+        x = torch.randn(16, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
 
 @unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
 class PallasTestsCUDA(PallasTestsMixin, TestCase):
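
For context, here is a hand-written sketch (not the code Inductor generates) of the kind of JAX/Pallas kernel that corresponds to the `x.sum()` case exercised by `test_sum_reduction`: the whole input block is read, reduced with `jnp.sum`, and written to a shape-`()` output with `[...]`, mirroring the scalar-output store path added in torch/_inductor/codegen/pallas.py below. `interpret=True` is used so the sketch runs without a GPU/TPU.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _sum_kernel(x_ref, out_ref):
    # Read the full input block and reduce over all axes to a scalar.
    out_ref[...] = jnp.sum(x_ref[...])


def pallas_sum(x):
    # Shape-() output: a true scalar, matching the is_scalar store path below.
    return pl.pallas_call(
        _sum_kernel,
        out_shape=jax.ShapeDtypeStruct((), x.dtype),
        interpret=True,  # run via the Pallas interpreter for portability
    )(x)


# pallas_sum(jnp.arange(16, dtype=jnp.float32)) evaluates to 120.0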

torch/_inductor/codegen/pallas.py

Lines changed: 89 additions & 14 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import hashlib
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import sympy  # noqa: TC002
 
@@ -21,6 +21,7 @@
     from collections.abc import Callable, Sequence
 
     from ..ir import IRNode
+    from ..ops_handler import ReductionType
     from ..scheduler import BaseSchedulerNode
 
 
@@ -747,22 +748,95 @@ def store(
         if self.use_masked_ops is None:
             self.use_masked_ops = self._determine_masked_ops_for_kernel()
 
-        index_str, needs_flatten = self._get_index_expr(index)
+        # Check if this is a scalar output (reduction to scalar)
+        # Only shape () is a true scalar, not (1,) which is a 1-element tensor
+        try:
+            buf = V.graph.get_buffer(name)
+            output_shape = buf.get_size()
+            is_scalar = len(output_shape) == 0
+        except Exception:
+            is_scalar = False
+
+        if is_scalar:
+            # For scalar outputs, use [...] to assign the entire scalar
+            store_expr = f"{out}[...] = {value}"
+        else:
+            index_str, needs_flatten = self._get_index_expr(index)
 
-        # Build store expression using string concatenation
-        use_masked = index_str == "..." and not needs_flatten and self.use_masked_ops
+            # Build store expression using string concatenation
+            use_masked = (
+                index_str == "..." and not needs_flatten and self.use_masked_ops
+            )
 
-        if use_masked:
-            # GPU masked store: flatten tensor and apply per-tensor mask
-            mask_var = self._get_or_create_mask(name)
-            store_expr = (
-                f"pltriton.store({out}.at[pl.ds(block_size)], {value}, mask={mask_var})"
+            if use_masked:
+                # GPU masked store: flatten tensor and apply per-tensor mask
+                mask_var = self._get_or_create_mask(name)
+                store_expr = f"pltriton.store({out}.at[pl.ds(block_size)], {value}, mask={mask_var})"
+            else:
+                # Direct indexed assignment
+                store_expr = f"{out}[{index_str}] = {value}"
+
+        self.stores.writeline(store_expr)
+
+    def reduction(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[CSEVariable, tuple[CSEVariable, ...]],
+    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:  # type: ignore[override]
+        """
+        Generate code for reduction operations in JAX/Pallas.
+
+        Reductions in Pallas work by:
+        1. Loading the input data into the kernel
+        2. Applying JAX reduction operations (jnp.sum, jnp.max, etc.)
+        3. Storing the reduced result
+
+        The reduction happens over the loaded block of data.
+        """
+        assert self.inside_reduction
+
+        if isinstance(value, tuple):
+            raise Unsupported(
+                "Tuple reductions (e.g., welford_combine) not supported in Pallas backend"
             )
+
+        # Check if this reduction is already cached
+        cache_key = (src_dtype, reduction_type, value)
+        if cache_key in self.cse.reduction_cache:
+            return self.cse.reduction_cache[cache_key]
+
+        # Map reduction types to JAX functions
+        reduction_ops = {
+            "sum": "jnp.sum",
+            "prod": "jnp.prod",  # CPU only - not supported in Pallas GPU (Triton) backend
+            "max": "jnp.max",
+            "min": "jnp.min",
+            "any": "jnp.any",
+        }
+
+        if reduction_type == "xor_sum":
+            reduction_expr = f"jnp.bitwise_xor.reduce({value})"
+        elif reduction_type in reduction_ops:
+            # Apply reduction over all axes to get scalar result
+            reduction_expr = f"{reduction_ops[reduction_type]}({value})"
         else:
-            # Direct indexed assignment
-            store_expr = f"{out}[{index_str}] = {value}"
+            raise Unsupported(
+                f"Reduction type '{reduction_type}' not yet supported in Pallas backend. "
+                f"Supported types: {list(reduction_ops.keys())}, xor_sum"
            )
 
-        self.stores.writeline(store_expr)
+        # Generate CSE variable for the reduction result
+        result = self.cse.generate(
+            self.compute,
+            reduction_expr,
+            dtype=dtype,
+        )
+
+        # Cache the result
+        self.cse.reduction_cache[cache_key] = result
+        return result
 
     @staticmethod
     def _buffer_is_contiguous(buffer_name: str) -> bool:
@@ -1074,8 +1148,9 @@ class PallasScheduling(SIMDScheduling):
 
     @classmethod
    def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
-        # Start minimal: no special features advertised
-        return OrderedSet()
+        # Pallas/JAX can handle reductions to single elements efficiently
+        # without requiring split reductions
+        return OrderedSet([BackendFeature.REDUCE_TO_SINGLE_ELEMENT])
 
     def define_kernel(
         self,
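
As a standalone illustration of the dispatch the new `reduction()` method performs, the sketch below reproduces the type-to-expression mapping outside of Inductor. The function name `reduction_expr` and the plain-string interface are illustrative only; the real method returns a CSE variable and caches results in `self.cse.reduction_cache`. Like the generated code above, it assumes a JAX version where `jnp.bitwise_xor` exposes a ufunc-style `.reduce`.

# Maps Inductor reduction types to the jax.numpy call emitted into the kernel body.
REDUCTION_OPS = {
    "sum": "jnp.sum",
    "prod": "jnp.prod",  # CPU only; unsupported by the Pallas GPU (Triton) backend
    "max": "jnp.max",
    "min": "jnp.min",
    "any": "jnp.any",
}


def reduction_expr(reduction_type: str, value: str) -> str:
    """Return the JAX expression string that reduces `value` to a scalar."""
    if reduction_type == "xor_sum":
        # There is no jnp.xor_sum helper, so fold with the ufunc's reduce.
        return f"jnp.bitwise_xor.reduce({value})"
    if reduction_type in REDUCTION_OPS:
        # Reduce over all axes of the loaded block.
        return f"{REDUCTION_OPS[reduction_type]}({value})"
    raise NotImplementedError(f"Reduction type '{reduction_type}' is not supported")


# reduction_expr("sum", "tmp0")     -> "jnp.sum(tmp0)"
# reduction_expr("xor_sum", "tmp0") -> "jnp.bitwise_xor.reduce(tmp0)"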
