@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import hashlib
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import sympy  # noqa: TC002
 
@@ -21,6 +21,7 @@
     from collections.abc import Callable, Sequence
 
     from ..ir import IRNode
+    from ..ops_handler import ReductionType
     from ..scheduler import BaseSchedulerNode
 
 
@@ -747,22 +748,95 @@ def store(
         if self.use_masked_ops is None:
             self.use_masked_ops = self._determine_masked_ops_for_kernel()
 
-        index_str, needs_flatten = self._get_index_expr(index)
+        # Check if this is a scalar output (reduction to scalar)
+        # Only shape () is a true scalar, not (1,) which is a 1-element tensor
+        try:
+            buf = V.graph.get_buffer(name)
+            output_shape = buf.get_size()
+            is_scalar = len(output_shape) == 0
+        except Exception:
+            is_scalar = False
+
+        if is_scalar:
+            # For scalar outputs, use [...] to assign the entire scalar
+            store_expr = f"{out}[...] = {value}"
+        else:
+            index_str, needs_flatten = self._get_index_expr(index)
 
-        # Build store expression using string concatenation
-        use_masked = index_str == "..." and not needs_flatten and self.use_masked_ops
+            # Build store expression using string concatenation
+            use_masked = (
+                index_str == "..." and not needs_flatten and self.use_masked_ops
+            )
 
-        if use_masked:
-            # GPU masked store: flatten tensor and apply per-tensor mask
-            mask_var = self._get_or_create_mask(name)
-            store_expr = (
-                f"pltriton.store({out}.at[pl.ds(block_size)], {value}, mask={mask_var})"
+            if use_masked:
+                # GPU masked store: flatten tensor and apply per-tensor mask
+                mask_var = self._get_or_create_mask(name)
+                store_expr = f"pltriton.store({out}.at[pl.ds(block_size)], {value}, mask={mask_var})"
+            else:
+                # Direct indexed assignment
+                store_expr = f"{out}[{index_str}] = {value}"
+
+        self.stores.writeline(store_expr)
+
+    def reduction(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[CSEVariable, tuple[CSEVariable, ...]],
+    ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:  # type: ignore[override]
+        """
+        Generate code for reduction operations in JAX/Pallas.
+
+        Reductions in Pallas work by:
+        1. Loading the input data into the kernel
+        2. Applying JAX reduction operations (jnp.sum, jnp.max, etc.)
+        3. Storing the reduced result
+
+        The reduction happens over the loaded block of data.
+        """
+        assert self.inside_reduction
+
+        if isinstance(value, tuple):
+            raise Unsupported(
+                "Tuple reductions (e.g., welford_combine) not supported in Pallas backend"
             )
+
+        # Check if this reduction is already cached
+        cache_key = (src_dtype, reduction_type, value)
+        if cache_key in self.cse.reduction_cache:
+            return self.cse.reduction_cache[cache_key]
+
+        # Map reduction types to JAX functions
+        reduction_ops = {
+            "sum": "jnp.sum",
+            "prod": "jnp.prod",  # CPU only - not supported in Pallas GPU (Triton) backend
+            "max": "jnp.max",
+            "min": "jnp.min",
+            "any": "jnp.any",
+        }
+
+        if reduction_type == "xor_sum":
+            reduction_expr = f"jnp.bitwise_xor.reduce({value})"
+        elif reduction_type in reduction_ops:
+            # Apply reduction over all axes to get scalar result
+            reduction_expr = f"{reduction_ops[reduction_type]}({value})"
         else:
-            # Direct indexed assignment
-            store_expr = f"{out}[{index_str}] = {value}"
+            raise Unsupported(
+                f"Reduction type '{reduction_type}' not yet supported in Pallas backend. "
+                f"Supported types: {list(reduction_ops.keys())}, xor_sum"
+            )
 
-        self.stores.writeline(store_expr)
+        # Generate CSE variable for the reduction result
+        result = self.cse.generate(
+            self.compute,
+            reduction_expr,
+            dtype=dtype,
+        )
+
+        # Cache the result
+        self.cse.reduction_cache[cache_key] = result
+        return result
 
     @staticmethod
     def _buffer_is_contiguous(buffer_name: str) -> bool:
@@ -1074,8 +1148,9 @@ class PallasScheduling(SIMDScheduling):
 
     @classmethod
     def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
-        # Start minimal: no special features advertised
-        return OrderedSet()
+        # Pallas/JAX can handle reductions to single elements efficiently
+        # without requiring split reductions
+        return OrderedSet([BackendFeature.REDUCE_TO_SINGLE_ELEMENT])
 
     def define_kernel(
         self,