diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index fab65cb0..14412361 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -54,6 +54,7 @@ from .type_propagation import LiteralType from .type_propagation import NumericType from .type_propagation import SequenceType +from .type_propagation import StackTensorType from .type_propagation import TensorType from .type_propagation import TileIndexType from .type_propagation import TypeInfo @@ -321,12 +322,14 @@ def build_rolled_reductions(self) -> None: graph_to_info = {} allow_loop = False - # First, check if any graph contains matmul with rdim + # First, check if any graph contains matmul or dev_ptrs stacking with rdim # If so, we can't roll any graphs in this reduction dimension can_roll_graphs = True for graph_info in self.graphs: roller = ReductionRoller(self, rdim, {}) - if roller.has_matmul_with_rdim(graph_info.graph): + if roller.has_matmul_with_rdim( + graph_info.graph + ) or roller.has_stack_tensor_with_rdim(graph_info.graph): can_roll_graphs = False break @@ -870,7 +873,9 @@ def visit_Assign(self, node: ast.Assign) -> None: assert isinstance(target.value, ExtendedAST) assert target.value._type_info is not None target_origin = target.value._type_info.origin # pyright: ignore[reportOptionalMemberAccess] - if not target_origin.is_host(): + if not target_origin.is_host() and not isinstance( + target.value._type_info, StackTensorType + ): # Get the variable name for the error message var_name = ( target.value.id @@ -895,7 +900,9 @@ def _assign_subscript(self, target: ast.Subscript, val: object) -> None: assert isinstance(target.value, ExtendedAST) assert target.value._type_info is not None target_origin = target.value._type_info.origin - assert target_origin.is_host() + assert target_origin.is_host() or isinstance( + target.value._type_info, StackTensorType + ) return hl.store( self.visit(target.value), # pyright: ignore[reportArgumentType] @@ -928,6 +935,8 @@ def visit_Subscript(self, node: ast.Subscript) -> object: if isinstance(node.slice, ast.Constant): return self.visit(value)[self.visit(node.slice)] # pyright: ignore[reportIndexIssue] raise exc.InvalidSequenceSubscription(node.slice) + if isinstance(type_info, StackTensorType): + return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType] if type_info is not None and type_info.origin.is_host(): return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType] return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType] diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index 0d2d7cc2..0efb8470 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -8,6 +8,7 @@ import sympy import torch +from torch._inductor.utils import triton_type from ..
import exc from .._compat import get_tensor_descriptor_fn_name @@ -19,10 +20,15 @@ from .variable_origin import BlockSizeOrigin if TYPE_CHECKING: + from collections.abc import Sequence + from ..runtime.config import Config from .device_function import TensorDescriptorArg from .inductor_lowering import CodegenState + SymIntLike = torch.SymInt | int + ShapeLike = Sequence[SymIntLike] + class IndexingStrategy: def codegen_load( @@ -296,6 +302,147 @@ def codegen_store( ) +class StackIndexingStrategy: + """ + Generate pointer math for stacking load/store to several device memory pointers sharing the same indexing. + + offset, mask are calculated for the tensor_like template tensor and then broadcasted to each dev_ptr + , with the results stacked. + + e.g. for a 1D offset tensor and a 1D dev_ptr array, the stack offset is: + stack_offset = dev_ptrs[:, None] + offset[None, :] + + """ + + @staticmethod + def get_broadcast_str( + stack_shape: ShapeLike, + subscript_shape: ShapeLike, + ) -> tuple[str, str]: + """ + Args: + stack_shape: shape of the dev_ptr tensor. + subscript_shape: shape of subscription for each individual tensor. + + Returns: + the broadcast str for dev_ptrs and individual tensor offset. + """ + stack_broadcast_keys = [":" for _ in stack_shape] + [ + "None" for _ in subscript_shape + ] + stack_broadcast = f"[{', '.join(stack_broadcast_keys)}]" + tensor_broadcast_keys = ["None" for _ in stack_shape] + [ + ":" for _ in subscript_shape + ] + tensor_broadcast = f"[{', '.join(tensor_broadcast_keys)}]" + + return stack_broadcast, tensor_broadcast + + @staticmethod + def get_mask_expr( + state: CodegenState, + indexing: SubscriptIndexing, + stack_shape: ShapeLike, + subscript_shape: ShapeLike, + ) -> ast.AST | None: + stack_broadcast, tensor_broadcast = StackIndexingStrategy.get_broadcast_str( + stack_shape, subscript_shape + ) + + mask_exprs = [] + dev_ptr_mask_exprs = [] + # Generate Mask + + for dim, size in enumerate(stack_shape): + if ( + index := CompileEnvironment.current().get_block_id(size) + ) is not None and (mask_var := state.codegen.mask_var(index)) is not None: + expand = state.tile_strategy.expand_str(stack_shape, dim) + dev_ptr_mask_exprs.append(f"({mask_var}{expand})") + + if dev_ptr_mask_exprs: + dev_ptr_mask_expr = f"({'&'.join(dev_ptr_mask_exprs)})" + if len(dev_ptr_mask_exprs) < len(stack_shape): + dev_ptr_mask_expr = f"tl.broadcast_to({dev_ptr_mask_expr}, {state.tile_strategy.shape_str(stack_shape)})" + dev_ptr_mask_expr = f"({dev_ptr_mask_expr}){stack_broadcast}" + mask_exprs.append(dev_ptr_mask_expr) + + if indexing.has_mask(): + mask_exprs.append(f"(tensor_mask){tensor_broadcast}") + return expr_from_string( + "&".join(mask_exprs), tensor_mask=indexing.mask_expr + ) + if mask_exprs: + return expr_from_string("&".join(mask_exprs)) + return None + + @staticmethod + def codegen_load( + state: CodegenState, + stack_tensor: tuple[torch.Tensor, torch.Tensor], + dev_ptrs_ast: ast.AST, + subscript: list[object], + extra_mask: ast.AST | None, + ) -> ast.AST: + tensor_like, dev_ptrs = stack_tensor + indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask) + subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript) + stack_shape = [*dev_ptrs.size()] + + mask_expr = StackIndexingStrategy.get_mask_expr( + state, indexing, stack_shape, subscripts_shape + ) + extra = ", other=0" + if mask_expr is None: + mask_expr = expr_from_string("None") + extra = "" + + stack_broadcast, tensor_broadcast = StackIndexingStrategy.get_broadcast_str( + 
stack_shape, subscripts_shape + ) + + dtype = triton_type(tensor_like.dtype) + return expr_from_string( + f"tl.load((base.to(tl.pointer_type({dtype}))){stack_broadcast} + (offset){tensor_broadcast}, mask{extra})", + base=dev_ptrs_ast, + offset=indexing.index_expr, + mask=mask_expr, + ) + + @staticmethod + def codegen_store( + state: CodegenState, + stack_tensor: tuple[torch.Tensor, torch.Tensor], + dev_ptrs_ast: ast.AST, + subscript: list[object], + value: ast.AST, + extra_mask: ast.AST | None, + ) -> ast.AST: + tensor_like, dev_ptrs = stack_tensor + indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask) + subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript) + stack_shape = [*dev_ptrs.size()] + + mask_expr = StackIndexingStrategy.get_mask_expr( + state, indexing, stack_shape, subscripts_shape + ) + if mask_expr is None: + mask_expr = expr_from_string("None") + + stack_broadcast, tensor_broadcast = StackIndexingStrategy.get_broadcast_str( + stack_shape, subscripts_shape + ) + + dtype = triton_type(tensor_like.dtype) + return expr_from_string( + f"tl.store(base.to(tl.pointer_type({dtype})){stack_broadcast} + (offset){tensor_broadcast}, value, mask)", + base=dev_ptrs_ast, + value=value, + offset=indexing.index_expr, + mask=mask_expr, + ) + + class SubscriptIndexing(NamedTuple): index_expr: ast.AST mask_expr: ast.AST diff --git a/helion/_compiler/roll_reduction.py b/helion/_compiler/roll_reduction.py index ee2f5b3d..d1c8ec6d 100644 --- a/helion/_compiler/roll_reduction.py +++ b/helion/_compiler/roll_reduction.py @@ -6,6 +6,7 @@ import torch from torch.fx import map_arg +from ..language import _MEMORY_OPS from ..language._tracing_ops import _for_loop from ..language._tracing_ops import _get_symnode from ..language._tracing_ops import _host_tensor @@ -277,6 +278,35 @@ def is_matmul_with_rdim(node: torch.fx.Node) -> bool: return any(is_matmul_with_rdim(node) for node in graph.nodes) + def has_stack_tensor_with_rdim(self, graph: torch.fx.Graph) -> bool: + """Check if a graph contains stack tensors with rdim inputs.""" + + def is_stack_with_rdim(node: torch.fx.Node) -> bool: + """Check if a node is a stack dev_ptr with rdim inputs.""" + if node.op != "call_function": + return False + + if node.target not in _MEMORY_OPS: + return False + + host_tensor = node.args[0] + + if not isinstance(host_tensor, tuple): + return False + + # Check if stack dims have rdim + if len(host_tensor) == 2: + assert isinstance(host_tensor[1], torch.fx.Node) + stack = host_tensor[1].meta.get("val", None) + if isinstance(stack, torch.Tensor): + for size in stack.size(): + block_idx = CompileEnvironment.current().get_block_id(size) + if block_idx == self.rdim.block_id: + return True + return False + + return any(is_stack_with_rdim(node) for node in graph.nodes) + def process(self, graph: torch.fx.Graph) -> torch.fx.Graph: for node in graph.nodes: if self.should_go_in_inner_graph(node): diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index 6f9f1a46..41070c19 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -27,6 +27,7 @@ from ..autotuner.config_spec import BlockSizeSpec from ..language._decorators import get_device_func_replacement from ..language._decorators import is_api_func +from ..language.stack_tensor import StackTensor from ..language.tile_proxy import Tile from ..language.tile_proxy import _CheckForIndexCalls from .ast_extension import ExtendedAST @@ -1294,6 +1295,86 @@ def 
propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo: return self.element_types[attr] +class StackTensorType(ClassType): + element_types: dict[str, TypeInfo] # pyright: ignore[reportIncompatibleVariableOverride] + + def proxy(self) -> StackTensor: # pyright: ignore[reportIncompatibleMethodOverride] + with proxy_tensor.disable_proxy_modes_tracing(): + fake_mode = torch._C._unset_dispatch_mode( # pyright: ignore[reportAttributeAccessIssue] + torch._C._TorchDispatchModeKey.FAKE # pyright: ignore[reportAttributeAccessIssue] + ) + try: + assert isinstance(self.element_types["tensor_like"], TensorType) + assert isinstance(self.element_types["dev_ptrs"], TensorType) + return StackTensor( + self.element_types["tensor_like"].proxy(), + self.element_types["dev_ptrs"].proxy(), + ) + finally: + assert fake_mode is not None + torch._C._set_dispatch_mode(fake_mode) # pyright: ignore[reportAttributeAccessIssue] + + def merge(self, other: TypeInfo) -> TypeInfo: + if isinstance(other, StackTensorType): + self_elements = self.element_types + other_elements = other.element_types + if set(self_elements.keys()) == set(other_elements.keys()): + return StackTensorType( + origin=other.origin, + element_types={ + key: self_elements[key].merge(other_elements[key]) + for key in self_elements + }, + ) + return super().merge(other) + + def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]: + tensor_like_type = self.element_types["tensor_like"] + assert isinstance(tensor_like_type, TensorType) + size_like = tensor_like_type._device_indexing_size(key) + + dev_ptrs_type = self.element_types["dev_ptrs"] + assert isinstance(dev_ptrs_type, TensorType) + stack_size = list(dev_ptrs_type.fake_value.size()) + + return stack_size + size_like + + def propagate_setitem( + self, key: TypeInfo, value: TypeInfo, origin: Origin + ) -> TypeInfo: + if origin.is_host(): + warning(exc.TensorOperationInWrapper) + else: + lhs_shape = self._device_indexing_size(key) + lhs_rank = len(lhs_shape) + if isinstance(value, TensorType): + rhs_rank = value.fake_value.ndim + if lhs_rank != rhs_rank: + raise exc.RankMismatch( + lhs_rank, + rhs_rank, + f"LHS shape: {tuple(lhs_shape)}, RHS shape: {tuple(value.fake_value.shape)}", + ) + elif isinstance(value, (NumericType, LiteralType)): + # Allow scalar assignment to tensor (broadcasts to tensor shape) + pass + else: + raise exc.RequiresTensorInAssignment(value) + return self + + def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: + if origin.is_host(): + warning(exc.TensorOperationInWrapper) + + assert isinstance(self.element_types["tensor_like"], TensorType) + return TensorType( + origin, + self.element_types["tensor_like"] + .proxy() + .new_empty(self._device_indexing_size(key)), + ) + + class SliceType(CollectionType): element_types: slice # pyright: ignore[reportIncompatibleVariableOverride] @@ -1619,7 +1700,7 @@ def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None: if isinstance(lhs, ast.Subscript): # TODO(jansel): test different types of subscript lhs_base_type = self.visit(lhs.value) - if isinstance(lhs_base_type, TensorType): + if isinstance(lhs_base_type, (TensorType, StackTensorType)): self.visit(lhs) # need to populate shape info lhs_base_type = lhs_base_type.propagate_setitem( self.visit(lhs.slice), rhs, self.origin() diff --git a/helion/exc.py b/helion/exc.py index b27f3d03..e958f935 100644 --- a/helion/exc.py +++ b/helion/exc.py @@ -138,6 +138,24 @@ class SpecializeArgType(BaseError): message = "hl.specialize() must be called 
on a size from an input tensor, got: {}" +class StackTensorOnHost(BaseError): + message = "StackTensor must be created inside the `hl.tile` or `hl.grid` loop." + + +class StackTensorDevPtrOnHost(BaseError): + message = "StackTensor must be created from a dev_ptr tensor defined on device. Use `hl.load` to load a dev_ptrs tensor." + + +class StackTensorDevPtrDtype(BaseError): + message = ( + "StackTensor must be created from a dev_ptr tensor of dtype uint64. Got: {0!s}" + ) + + +class StackTensorExampleOnDevice(BaseError): + message = "hl.stacktensor_like must be called with an example host tensor." + + class FailedToUnpackTupleAssign(BaseError): message = "Failed to unpack values in tuple assignment. Expected a sequence of size {0}, got type: {1!s}." diff --git a/helion/language/__init__.py b/helion/language/__init__.py index 3b8c5946..ad5bc0fa 100644 --- a/helion/language/__init__.py +++ b/helion/language/__init__.py @@ -20,6 +20,7 @@ from .scan_ops import cumsum as cumsum from .signal_wait import signal as signal from .signal_wait import wait as wait +from .stack_tensor import stacktensor_like as stacktensor_like from .tile_ops import tile_begin as tile_begin from .tile_ops import tile_block_size as tile_block_size from .tile_ops import tile_end as tile_end @@ -30,3 +31,5 @@ from .tunable_ops import register_reduction_dim as register_reduction_dim from .tunable_ops import register_tunable as register_tunable from .view_ops import subscript as subscript + +_MEMORY_OPS = (store, load, atomic_add, wait, signal) diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index 2bf4f9a8..b936ee92 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -11,6 +11,7 @@ from .._compiler.ast_extension import expr_from_string from .._compiler.indexing_strategy import SubscriptIndexing from . import _decorators +from helion.language.stack_tensor import StackTensor if TYPE_CHECKING: from .._compiler.inductor_lowering import CodegenState @@ -21,7 +22,7 @@ @has_side_effect @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) def store( - tensor: torch.Tensor, + tensor: torch.Tensor | StackTensor, index: list[object], value: torch.Tensor | torch.SymInt | float, extra_mask: torch.Tensor | None = None, ) -> None: """Store a value to a tensor using a list of indices. This function is equivalent to `tensor[index] = value` but allows multiple indices to be passed at once. The index into tensor must be based on the hl.tile range.
Args: - tensor: The tensor to store to + tensor: The tensor / stack tensor to store to index: The indices to use to index into the tensor value: The value to store extra_mask: The extra mask (beyond automatic tile bounds masking) to apply to the tensor @@ -45,24 +46,34 @@ def store( @_decorators.prepare_args(store) def _( - tensor: torch.Tensor, + tensor: torch.Tensor | StackTensor, index: list[object], value: torch.Tensor | torch.SymInt | float, extra_mask: torch.Tensor | None = None, ) -> tuple[ - torch.Tensor, list[object], torch.Tensor | torch.SymInt | float, torch.Tensor | None + torch.Tensor | tuple, + list[object], + torch.Tensor | torch.SymInt | float, + torch.Tensor | None, ]: from .tile_proxy import Tile if isinstance(value, torch.Tensor) and value.dtype != tensor.dtype: value = value.to(tensor.dtype) index = Tile._tiles_to_sizes(index) - return (tensor, index, value, extra_mask) + + if isinstance(tensor, StackTensor): + return (tuple(tensor), index, value, extra_mask) + + if isinstance(tensor, torch.Tensor): + return (tensor, index, value, extra_mask) + + raise NotImplementedError(f"Cannot store to type: {type(tensor)}") @_decorators.register_fake(store) def _( - tensor: torch.Tensor, + tensor: torch.Tensor | tuple[object, ...], index: list[object], value: torch.Tensor | torch.SymInt | float, extra_mask: torch.Tensor | None = None, @@ -73,17 +84,30 @@ def _( @_decorators.codegen(store) def _(state: CodegenState) -> ast.AST: tensor = state.proxy_arg(0) - assert isinstance(tensor, torch.Tensor) subscript = state.proxy_arg(1) assert isinstance(subscript, (list, tuple)) value = state.ast_arg(2) extra_mask = state.ast_args[3] assert isinstance(extra_mask, (type(None), ast.AST)) - return state.device_function.indexing_strategy.codegen_store( - state, tensor, [*subscript], value, extra_mask - ) + + if isinstance(tensor, torch.Tensor): + return state.device_function.indexing_strategy.codegen_store( + state, tensor, [*subscript], value, extra_mask + ) + if isinstance(tensor, tuple): + from .._compiler.indexing_strategy import StackIndexingStrategy + + stack_tensor_ast = state.ast_args[0] + assert isinstance(stack_tensor_ast, tuple) + assert len(stack_tensor_ast) == 2 + tensor_like_ast, dev_ptrs_ast = stack_tensor_ast + return StackIndexingStrategy.codegen_store( + state, tensor, dev_ptrs_ast, [*subscript], value, extra_mask + ) + raise NotImplementedError(f"Cannot store to type: {type(tensor)}") +# TODO(joydddd): Add support for stack tensor in ref mode. @_decorators.ref(store) def _( tensor: torch.Tensor, @@ -120,7 +144,7 @@ def _( @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) def load( - tensor: torch.Tensor, + tensor: torch.Tensor | StackTensor, index: list[object], extra_mask: torch.Tensor | None = None, ) -> torch.Tensor: @@ -131,7 +155,7 @@ def load( based on the hl.tile range. 
Args: - tensor: The tensor to load from + tensor: The tensor / stack tensor to load from index: The indices to use to index into the tensor extra_mask: The extra mask (beyond automatic tile bounds masking) to apply to the tensor Returns: @@ -140,24 +164,63 @@ def _( raise exc.NotInsideKernel + +@_decorators.prepare_args(load) +def _( + tensor: torch.Tensor | StackTensor, + index: list[object], + extra_mask: torch.Tensor | None = None, +) -> tuple[torch.Tensor | tuple, list[object], torch.Tensor | None]: + from .tile_proxy import Tile + + index = Tile._tiles_to_sizes(index) + if isinstance(tensor, StackTensor): + return (tuple(tensor), index, extra_mask) + assert isinstance(tensor, torch.Tensor) + return (tensor, index, extra_mask) + + @_decorators.register_fake(load) def _( - tensor: torch.Tensor, index: list[object], extra_mask: torch.Tensor | None = None + tensor: torch.Tensor | tuple[object, ...], + index: list[object], + extra_mask: torch.Tensor | None = None, ) -> torch.Tensor: - return tensor.new_empty(SubscriptIndexing.compute_shape(tensor, index)) + if isinstance(tensor, torch.Tensor): + target_shape = SubscriptIndexing.compute_shape(tensor, index) + return tensor.new_empty(target_shape) + if isinstance(tensor, tuple): + tensor_like, dev_ptrs = tensor + assert isinstance(tensor_like, torch.Tensor) + assert isinstance(dev_ptrs, torch.Tensor) + tensor_shape = SubscriptIndexing.compute_shape(tensor_like, index) + target_shape = list(dev_ptrs.size()) + tensor_shape + return tensor_like.new_empty(target_shape) + raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}") @_decorators.codegen(load) def _(state: CodegenState) -> ast.AST: tensor = state.proxy_arg(0) - assert isinstance(tensor, torch.Tensor) subscript = state.proxy_arg(1) assert isinstance(subscript, (list, tuple)) extra_mask = state.ast_args[2] assert isinstance(extra_mask, (type(None), ast.AST)) - return state.device_function.indexing_strategy.codegen_load( - state, tensor, [*subscript], extra_mask - ) + + if isinstance(tensor, torch.Tensor): + return state.device_function.indexing_strategy.codegen_load( + state, tensor, [*subscript], extra_mask + ) + if isinstance(tensor, tuple): + from .._compiler.indexing_strategy import StackIndexingStrategy + + stack_tensor_ast = state.ast_args[0] + assert isinstance(stack_tensor_ast, tuple) + assert len(stack_tensor_ast) == 2 + tensor_like_ast, dev_ptrs_ast = stack_tensor_ast + return StackIndexingStrategy.codegen_load( + state, tensor, dev_ptrs_ast, [*subscript], extra_mask + ) + raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}") @_decorators.get_masked_value(load) @@ -165,6 +228,7 @@ def _(node: torch.fx.Node) -> int: return 0 # loads are always masked to 0 +# TODO(joydddd): Add support for stack tensor in ref mode. @_decorators.ref(load) def _( tensor: torch.Tensor, diff --git a/helion/language/stack_tensor.py b/helion/language/stack_tensor.py new file mode 100644 index 00000000..76193494 --- /dev/null +++ b/helion/language/stack_tensor.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import NamedTuple + +import torch + +from .. import exc +from . import _decorators + +if TYPE_CHECKING: + from .._compiler.type_propagation import TypeInfo + from .._compiler.variable_origin import Origin + + +class StackTensor(NamedTuple): + """ + StackTensor is a batch of tensors with the same properties (shape, dtype and stride) + that reside at different memory locations, virtually stacked together.
+ It provides a way to perform parallel memory accesses to multiple tensors with a single subscription. + + **Core Concept:** + Instead of performing separate memory operations on each tensor individually, + StackTensor allows you to broadcast a single memory operation (hl.load, hl.store, hl.atomic_add, + hl.signal, hl.wait etc.) to multiple tensor buffers in parallel. This is particularly useful + for batch processing scenarios where the same operation needs to be applied to multiple tensors. + + **Memory Operation Behavior:** + - **Loads**: When you index into a StackTensor (e.g., `stack_tensor[i]`), + it performs the same indexing operation on all underlying tensor buffers and + returns a new tensor where the results are stacked according to the shape of dev_ptrs. + - **Stores**: When you assign to a StackTensor (e.g., `stack_tensor[i] = value`), + the value tensor is "unstacked" - each slice of the value tensor is written to the respective + underlying tensor buffer. This is the reverse operation of loading. + (e.g. value[j] is written to tensor_j[i]). + + **Shape Semantics:** + The StackTensor's shape is `dev_ptrs.shape + tensor_like.shape`, where: + - `dev_ptrs.shape` represents the "batch" dimensions (how tensors are being stacked) + - `tensor_like.shape` represents the shape of each individual tensor + + + Attributes: + tensor_like: A template host tensor that defines the shape, dtype, and other properties + for all tensors in the stack group. + dev_ptrs: A tensor containing device pointers (memory buffer addresses) to the actual + tensors in device memory. Must be of dtype torch.uint64. + + Properties: + dtype: The data type of the tensors in the stack group. Inherited from tensor_like. + shape: The shape of the stacked tensor. Computed as dev_ptrs.shape + tensor_like.shape. + """ + + tensor_like: torch.Tensor + dev_ptrs: torch.Tensor + + @property + def dtype(self) -> torch.dtype: + return self.tensor_like.dtype + + @property + def device(self) -> torch.device: + return self.tensor_like.device + + @property + def shape(self) -> torch.Size: + return self.dev_ptrs.shape + self.tensor_like.shape + + def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride] + self, + index: list[object] | torch.Tensor, + ) -> torch.Tensor: + raise exc.NotInsideKernel + + def __setitem__( # pyright: ignore[reportIncompatibleMethodOverride] + self, + index: list[object] | torch.Tensor, + value: torch.Tensor | bool | float, + ) -> None: + raise exc.NotInsideKernel + + # TODO(joydddd): Implement this to support StackTensor in ref mode. + # def as_tuple_of_tensor(self) -> tuple[torch.Tensor, ...]: + """ + Returns a tuple of tensors that represent the underlying buffers of the stack tensor. + + This function is useful when you need to access the underlying tensors directly, + for example, to run in eager mode. + + """ + + +def stacktensor_like( + tensor_like: torch.Tensor, + dev_ptrs: torch.Tensor, +) -> StackTensor: + """ + Creates a StackTensor from a tensor of data pointers (dev_ptrs) pointing to tensors that match tensor_like + but reside at different memory locations. + + This function creates a StackTensor that allows you to broadcast memory operations + to multiple tensor buffers in parallel. + + Must be called inside a helion kernel with dev_ptrs as a device tensor and tensor_like + as a host tensor. + + Args: + tensor_like: A template host tensor that defines the shape, dtype, and other properties + that each buffer in the stack group should have. Must be a host tensor.
+ dev_ptrs: A tensor containing device pointers (memory addresses) to data buffers. + Must be of dtype torch.uint64 and must be a device tensor. + + Examples: + **Basic Load Operation:** + + .. code-block:: python + + @helion.kernel + def stack_load(dev_ptrs: torch.Tensor, example: torch.Tensor): + for tile in hl.tile(example.size(0)): + ptr_tile = dev_ptrs[:] # Shape: [num_tensors] + stack_tensor = hl.stacktensor_like(example, ptr_tile) + # Load from all tensors simultaneously + data = stack_tensor[tile] # Shape: [num_tensors, tile_size] + return data + + **Store Operation:** + + .. code-block:: python + + @helion.kernel + def stack_store( + dev_ptrs: torch.Tensor, example: torch.Tensor, values: torch.Tensor + ): + ptr_tile = dev_ptrs[:] # Shape: [num_tensors] + stack_tensor = hl.stacktensor_like(example, ptr_tile) + + # Store values of shape [num_tensors, N] to all tensors in parallel + stack_tensor[:] = values # slice values[i, :] goes to tensor i + + **Usage Setup:** + + .. code-block:: python + + # Create list of tensors to process + tensor_list = [torch.randn(16, device="cuda") for _ in range(4)] + tensor_ptrs = torch.as_tensor( + [p.data_ptr() for p in tensor_list], dtype=torch.uint64, device="cuda" + ) + result = stack_load(tensor_ptrs, tensor_list[0]) + + Returns: + A StackTensor object that broadcasts memory operations to all data buffers + pointed to by dev_ptrs. + """ + raise exc.NotInsideKernel + + +@_decorators.device_func_replacement(stacktensor_like) +@_decorators.device_func_replacement(StackTensor) +@_decorators.api(is_device_only=False, allow_host_tensor=True) +def _stack_tensor( + tensor_like: torch.Tensor, + dev_ptrs: torch.Tensor, +) -> StackTensor: + raise exc.NotInsideKernel + + +@_decorators.type_propagation(_stack_tensor) +def _(tensor_like: TypeInfo, dev_ptrs: TypeInfo, *, origin: Origin) -> TypeInfo: + from .._compiler.type_propagation import StackTensorType + from .._compiler.type_propagation import TensorType + + assert isinstance(dev_ptrs, TensorType) + assert isinstance(tensor_like, TensorType) + if origin.is_host(): + raise exc.StackTensorOnHost + if dev_ptrs.origin.is_host(): + raise exc.StackTensorDevPtrOnHost + if tensor_like.origin.is_device(): + raise exc.StackTensorExampleOnDevice + if dev_ptrs.fake_value.dtype != torch.uint64: + raise exc.StackTensorDevPtrDtype(dev_ptrs.fake_value.dtype) + element_types = { + "dev_ptrs": dev_ptrs, + "tensor_like": tensor_like, + } + + return StackTensorType(origin, element_types) # pyright: ignore[reportArgumentType] + + +@_decorators.register_to_device_ir(_stack_tensor) +def _(tracer: object, tensor_like: torch.Tensor, dev_ptrs: torch.Tensor) -> StackTensor: + return StackTensor(tensor_like, dev_ptrs) diff --git a/test/test_stack_tensor.expected b/test/test_stack_tensor.expected new file mode 100644 index 00000000..d6e67638 --- /dev/null +++ b/test/test_stack_tensor.expected @@ -0,0 +1,222 @@ +This file is automatically generated by assertExpectedJournal calls in test_stack_tensor.py. +Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
+ +--- assertExpectedJournal(TestStackTensor.test_stack_load_2d_dev_ptrs) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _stack_load_kernel_2d_kernel(dev_ptrs, out, dev_ptrs_stride_0, dev_ptrs_stride_1, example_tensor_stride_0, out_stride_0, out_stride_1, out_stride_2, N, M2, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < N + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + mask_1 = indices_1 < M2 + ptr_tile = tl.load(dev_ptrs + (indices_1[:, None] * dev_ptrs_stride_0 + indices_1[None, :] * dev_ptrs_stride_1), mask_1[:, None] & mask_1[None, :], other=0) + load_1 = tl.load(ptr_tile.to(tl.pointer_type(tl.bfloat16))[:, :, None] + (indices_0 * example_tensor_stride_0)[None, None, :], (mask_1[:, None] & mask_1[None, :])[:, :, None] & mask_0[None, None, :], other=0) + tl.store(out + (indices_1[:, None, None] * out_stride_0 + indices_1[None, :, None] * out_stride_1 + indices_0[None, None, :] * out_stride_2), load_1, mask_1[:, None, None] & mask_1[None, :, None] & mask_0[None, None, :]) + +def stack_load_kernel_2d(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _launcher=_default_launcher): + M1, M2 = dev_ptrs.size() + N = example_tensor.size(0) + out = torch.empty(M1, M2, N, dtype=torch.bfloat16, device=dev_ptrs.device) + _BLOCK_SIZE_0 = 4 + _RDIM_SIZE_1 = triton.next_power_of_2(M2) + _launcher(_stack_load_kernel_2d_kernel, (triton.cdiv(N, _BLOCK_SIZE_0),), dev_ptrs, out, dev_ptrs.stride(0), dev_ptrs.stride(1), example_tensor.stride(0), out.stride(0), out.stride(1), out.stride(2), N, M2, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + return outfrom __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _stack_load_2d_looped_kernel(dev_ptrs, out, dev_ptrs_stride_0, dev_ptrs_stride_1, example_tensor_stride_0, out_stride_0, out_stride_1, out_stride_2, N, M2, M1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < N + indices_2 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32) + mask_2 = indices_2 < M2 + for offset_1 in tl.range(0, M1.to(tl.int32)): + ptr_tile = tl.load(dev_ptrs + (offset_1 * dev_ptrs_stride_0 + indices_2 * dev_ptrs_stride_1), mask_2, other=0) + load_1 = tl.load(ptr_tile.to(tl.pointer_type(tl.bfloat16))[:, None] + (indices_0 * example_tensor_stride_0)[None, :], mask_2[:, None] & mask_0[None, :], other=0) + tl.store(out + (offset_1 * out_stride_0 + indices_2[:, None] * out_stride_1 + indices_0[None, :] * out_stride_2), load_1, mask_2[:, None] & mask_0[None, :]) + +def stack_load_2d_looped(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _launcher=_default_launcher): + M1, M2 = dev_ptrs.size() + N = example_tensor.size(0) + out = torch.empty(M1, M2, N, dtype=torch.bfloat16, device=dev_ptrs.device) + _BLOCK_SIZE_0 = 4 + _RDIM_SIZE_2 = triton.next_power_of_2(M2) + _launcher(_stack_load_2d_looped_kernel, (triton.cdiv(N, _BLOCK_SIZE_0),), dev_ptrs, out, dev_ptrs.stride(0), dev_ptrs.stride(1), example_tensor.stride(0), out.stride(0), out.stride(1), out.stride(2), N, M2, M1, _BLOCK_SIZE_0, 
_RDIM_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestStackTensor.test_stack_load_2d_tensors) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _stack_load_kernel_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, example_tensor_stride_1, out_stride_0, out_stride_1, out_stride_2, N1, N2, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(N1, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < N1 + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < N2 + indices_2 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32) + mask_2 = indices_2 < M + ptr_tile = tl.load(dev_ptrs + indices_2 * dev_ptrs_stride_0, mask_2, other=0) + load_1 = tl.load(ptr_tile.to(tl.pointer_type(tl.bfloat16))[:, None, None] + (indices_0[:, None] * example_tensor_stride_0 + indices_1[None, :] * example_tensor_stride_1)[None, :, :], mask_2[:, None, None] & (mask_0[:, None] & mask_1[None, :])[None, :, :], other=0) + tl.store(out + (indices_2[:, None, None] * out_stride_0 + indices_0[None, :, None] * out_stride_1 + indices_1[None, None, :] * out_stride_2), load_1, mask_2[:, None, None] & mask_0[None, :, None] & mask_1[None, None, :]) + +def stack_load_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _launcher=_default_launcher): + M = dev_ptrs.size(0) + N1, N2 = example_tensor.size() + out = torch.empty(M, N1, N2, dtype=torch.bfloat16, device=dev_ptrs.device) + _BLOCK_SIZE_0 = 4 + _BLOCK_SIZE_1 = 4 + _RDIM_SIZE_2 = triton.next_power_of_2(M) + _launcher(_stack_load_kernel_kernel, (triton.cdiv(N1, _BLOCK_SIZE_0) * triton.cdiv(N2, _BLOCK_SIZE_1),), dev_ptrs, out, dev_ptrs.stride(0), example_tensor.stride(0), example_tensor.stride(1), out.stride(0), out.stride(1), out.stride(2), N1, N2, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestStackTensor.test_stack_load_grid) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _stack_load_kernel_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, out_stride_0, out_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + ptr_tile = tl.load(dev_ptrs + indices_1 * dev_ptrs_stride_0, None) + load_1 = tl.load(ptr_tile.to(tl.pointer_type(tl.bfloat16))[:] + (offset_0 * example_tensor_stride_0)[None], None) + tl.store(out + (indices_1 * out_stride_0 + offset_0 * out_stride_1), load_1, None) + +def stack_load_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _launcher=_default_launcher): + M = 4 + N = example_tensor.size(0) + out = torch.empty(M, N, dtype=torch.bfloat16, device=dev_ptrs.device) + _RDIM_SIZE_1 = 4 + _launcher(_stack_load_kernel_kernel, (N,), dev_ptrs, out, dev_ptrs.stride(0), example_tensor.stride(0), out.stride(0), out.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestStackTensor.test_stack_mask) +from __future__ import annotations + 
+import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _stack_load_w_mask_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, out_stride_0, out_stride_1, N, M, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < N + indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32) + mask_2 = indices_3 < M + for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_2 < M + ptr_tile = tl.load(dev_ptrs + indices_2 * dev_ptrs_stride_0, mask_1, other=0) + load_1 = tl.load(ptr_tile.to(tl.pointer_type(tl.bfloat16))[:, None] + (indices_0 * example_tensor_stride_0)[None, :], mask_1[:, None] & mask_0[None, :], other=0) + tl.store(out + (indices_3[:, None] * out_stride_0 + indices_0[None, :] * out_stride_1), load_1, mask_2[:, None] & mask_0[None, :]) + +def stack_load_w_mask(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _launcher=_default_launcher): + M = dev_ptrs.size(0) + N = example_tensor.size(0) + out = torch.empty(M, N, dtype=torch.bfloat16, device=dev_ptrs.device) + _BLOCK_SIZE_0 = 4 + _RDIM_SIZE_2 = triton.next_power_of_2(M) + _BLOCK_SIZE_1 = 4 + _launcher(_stack_load_w_mask_kernel, (triton.cdiv(N, _BLOCK_SIZE_0),), dev_ptrs, out, dev_ptrs.stride(0), example_tensor.stride(0), out.stride(0), out.stride(1), N, M, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestStackTensor.test_stack_store_broadcast_masked) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _stack_store_kernel_kernel(dev_ptrs, x, dev_ptrs_stride_0, example_tensor_stride_0, x_stride_0, N, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < N + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + mask_1 = indices_1 < 3 + ptr_tile = tl.load(dev_ptrs + indices_1 * dev_ptrs_stride_0, mask_1, other=0) + x_tile = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + subscript = x_tile[None, :] + tl.store(ptr_tile.to(tl.pointer_type(tl.bfloat16))[:, None] + (indices_0 * example_tensor_stride_0)[None, :], subscript, mask_1[:, None] & mask_0[None, :]) + +def stack_store_kernel(x: torch.Tensor, dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _launcher=_default_launcher): + N = x.size(0) + _BLOCK_SIZE_0 = 4 + _RDIM_SIZE_1 = 4 + _launcher(_stack_store_kernel_kernel, (triton.cdiv(N, _BLOCK_SIZE_0),), dev_ptrs, x, dev_ptrs.stride(0), example_tensor.stride(0), x.stride(0), N, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + +--- assertExpectedJournal(TestStackTensor.test_stack_store_grid) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _stack_store_kernel_kernel(dev_ptrs, x, dev_ptrs_stride_0, example_tensor_stride_0, x_stride_0, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + ptr_tile = 
tl.load(dev_ptrs + indices_1 * dev_ptrs_stride_0, None) + load_1 = tl.load(x + offset_0 * x_stride_0, None) + tl.store(ptr_tile.to(tl.pointer_type(tl.bfloat16))[:] + (offset_0 * example_tensor_stride_0)[None], load_1, None) + +def stack_store_kernel(x: torch.Tensor, dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _launcher=_default_launcher): + N = x.size(0) + _RDIM_SIZE_1 = 4 + _launcher(_stack_store_kernel_kernel, (N,), dev_ptrs, x, dev_ptrs.stride(0), example_tensor.stride(0), x.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3) + +--- assertExpectedJournal(TestStackTensor.test_stack_store_scatter) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _stack_store_arange_kernel_kernel(dev_ptrs, dev_ptrs_stride_0, example_tensor_stride_0, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + ptr_tile = tl.load(dev_ptrs + indices_1 * dev_ptrs_stride_0, None) + x = tl.arange(0, 4) + tl.store(ptr_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_tensor_stride_0)[None], x, None) + +def stack_store_arange_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _launcher=_default_launcher): + N = example_tensor.size(0) + _RDIM_SIZE_1 = 4 + _launcher(_stack_store_arange_kernel_kernel, (N,), dev_ptrs, dev_ptrs.stride(0), example_tensor.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3) diff --git a/test/test_stack_tensor.py b/test/test_stack_tensor.py new file mode 100644 index 00000000..6f34c6b8 --- /dev/null +++ b/test/test_stack_tensor.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import unittest + +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled +from helion._testing import TestCase +from helion._testing import code_and_output +import helion.language as hl + + +class TestStackTensor(RefEagerTestDisabled, TestCase): + def test_stack_load_grid(self): + @helion.kernel + def stack_load_kernel( + dev_ptrs: torch.Tensor, + example_tensor: torch.Tensor, + ) -> torch.Tensor: + M = hl.specialize(dev_ptrs.size(0)) + N = example_tensor.size(0) + out = torch.empty(M, N, dtype=torch.bfloat16, device=dev_ptrs.device) + + for i in hl.grid(N): + ptr_tile = dev_ptrs[:] + tensors = hl.stacktensor_like(example_tensor, ptr_tile) + out[:, i] = tensors[i] + return out + + tensor_list = [ + torch.randn(4, device=DEVICE, dtype=torch.bfloat16) for _ in range(4) + ] + tensor_ptrs = torch.as_tensor( + [p.data_ptr() for p in tensor_list], device=DEVICE, dtype=torch.uint64 + ) + code, result = code_and_output(stack_load_kernel, (tensor_ptrs, tensor_list[0])) + torch.testing.assert_close(result, torch.stack(tensor_list)) + self.assertExpectedJournal(code) + + def test_stack_load_2d_tensors(self): + @helion.kernel + def stack_load_kernel( + dev_ptrs: torch.Tensor, + example_tensor: torch.Tensor, + ) -> torch.Tensor: + M = dev_ptrs.size(0) + N1, N2 = example_tensor.size() + out = torch.empty(M, N1, N2, dtype=torch.bfloat16, device=dev_ptrs.device) + + for tile1, tile2 in hl.tile([N1, N2]): + ptr_tile = dev_ptrs[:] + tensors = hl.stacktensor_like(example_tensor, ptr_tile) + out[:, tile1, tile2] = tensors[tile1, tile2] + return out + + tensor_list = [ + torch.randn(4, 4, device=DEVICE, dtype=torch.bfloat16) for _ in range(8) + ] + tensor_ptrs = torch.as_tensor( + [p.data_ptr() for p in tensor_list], device=DEVICE, 
dtype=torch.uint64 + ) + + code, result = code_and_output( + stack_load_kernel, (tensor_ptrs, tensor_list[0]), block_size=[4, 4] + ) + torch.testing.assert_close(result, torch.stack(tensor_list)) + + self.assertExpectedJournal(code) + + def test_stack_load_2d_dev_ptrs(self): + @helion.kernel + def stack_load_kernel_2d( + dev_ptrs: torch.Tensor, + example_tensor: torch.Tensor, + ) -> torch.Tensor: + M1, M2 = dev_ptrs.size() + N = example_tensor.size(0) + out = torch.empty(M1, M2, N, dtype=torch.bfloat16, device=dev_ptrs.device) + + for tile in hl.tile(N, block_size=4): + ptr_tile = dev_ptrs[:, :] + tensors = hl.stacktensor_like(example_tensor, ptr_tile) + out[:, :, tile] = tensors[tile] + return out + + tensor_list = [ + torch.randn(4, device=DEVICE, dtype=torch.bfloat16) for _ in range(16) + ] + tensor_ptrs = torch.as_tensor( + [p.data_ptr() for p in tensor_list], device=DEVICE, dtype=torch.uint64 + ).reshape(4, 4) + + code_batched, result = code_and_output( + stack_load_kernel_2d, (tensor_ptrs, tensor_list[0]) + ) + torch.testing.assert_close(result, torch.stack(tensor_list).reshape(4, 4, -1)) + + @helion.kernel + def stack_load_2d_looped( + dev_ptrs: torch.Tensor, + example_tensor: torch.Tensor, + ) -> torch.Tensor: + M1, M2 = dev_ptrs.size() + N = example_tensor.size(0) + out = torch.empty(M1, M2, N, dtype=torch.bfloat16, device=dev_ptrs.device) + + for tile in hl.tile(N, block_size=4): + for i in range(M1): + ptr_tile = dev_ptrs[i, :] + tensors = hl.stacktensor_like(example_tensor, ptr_tile) + out[i, :, tile] = tensors[tile] + return out + + code_looped, result = code_and_output( + stack_load_2d_looped, (tensor_ptrs, tensor_list[0]) + ) + torch.testing.assert_close(result, torch.stack(tensor_list).reshape(4, 4, -1)) + self.assertExpectedJournal(code_batched + code_looped) + + def test_stack_mask(self): + @helion.kernel + def stack_load_w_mask( + dev_ptrs: torch.Tensor, + example_tensor: torch.Tensor, + ) -> torch.Tensor: + M = dev_ptrs.size(0) + N = example_tensor.size(0) + out = torch.empty(M, N, dtype=torch.bfloat16, device=dev_ptrs.device) + + for tile in hl.tile(N, block_size=4): + for stack_tile in hl.tile(M, block_size=4): + ptr_tile = dev_ptrs[stack_tile] + tensors = hl.stacktensor_like(example_tensor, ptr_tile) + out[:, tile] = tensors[tile] + return out + + tensor_list = [ + torch.randn(15, device=DEVICE, dtype=torch.bfloat16) for _ in range(3) + ] + tensor_ptrs = torch.as_tensor( + [p.data_ptr() for p in tensor_list], device=DEVICE, dtype=torch.uint64 + ) + + code, result = code_and_output(stack_load_w_mask, (tensor_ptrs, tensor_list[0])) + torch.testing.assert_close(result, torch.stack(tensor_list)) + self.assertExpectedJournal(code) + + def test_stack_store_grid(self): + @helion.kernel + def stack_store_kernel( + x: torch.Tensor, + dev_ptrs: torch.Tensor, + example_tensor: torch.Tensor, + ) -> None: + N = x.size(0) + hl.specialize(dev_ptrs.size(0)) + + for i in hl.grid(N): + ptr_tile = dev_ptrs[:] + tensors = hl.stacktensor_like(example_tensor, ptr_tile) + tensors[i] = x[None, i] + + tensor_list = [ + torch.empty(16, device=DEVICE, dtype=torch.bfloat16) for _ in range(4) + ] + tensor_ptrs = torch.as_tensor( + [p.data_ptr() for p in tensor_list], device=DEVICE, dtype=torch.uint64 + ) + + x = torch.randn(16, device=DEVICE, dtype=torch.bfloat16) + code, result = code_and_output( + stack_store_kernel, (x, tensor_ptrs, tensor_list[0]) + ) + + for tensor in tensor_list: + torch.testing.assert_close(tensor, x) + + self.assertExpectedJournal(code) + + def 
test_stack_store_broadcast_masked(self): + @helion.kernel + def stack_store_kernel( + x: torch.Tensor, + dev_ptrs: torch.Tensor, + example_tensor: torch.Tensor, + ) -> None: + N = x.size(0) + hl.specialize(dev_ptrs.size(0)) + + for tile in hl.tile(N, block_size=4): + ptr_tile = dev_ptrs[:] + tensors = hl.stacktensor_like(example_tensor, ptr_tile) + x_tile = x[tile] + tensors[tile] = x_tile[None, :] + + tensor_list = [ + torch.empty(15, device=DEVICE, dtype=torch.bfloat16) for _ in range(3) + ] + tensor_ptrs = torch.as_tensor( + [p.data_ptr() for p in tensor_list], device=DEVICE, dtype=torch.uint64 + ) + + x = torch.randn(15, device=DEVICE, dtype=torch.bfloat16) + code, result = code_and_output( + stack_store_kernel, (x, tensor_ptrs, tensor_list[0]) + ) + + for tensor in tensor_list: + torch.testing.assert_close(tensor, x) + + self.assertExpectedJournal(code) + + def test_stack_store_scatter(self): + @helion.kernel + def stack_store_arange_kernel( + dev_ptrs: torch.Tensor, + example_tensor: torch.Tensor, + ) -> None: + N = example_tensor.size(0) + M = hl.specialize(dev_ptrs.size(0)) + + for i in hl.grid(N): + ptr_tile = dev_ptrs[:] + tensors = hl.stacktensor_like(example_tensor, ptr_tile) + x = hl.arange(M) + tensors[i] = x + + tensor_list = [ + torch.empty(15, device=DEVICE, dtype=torch.int32) for _ in range(4) + ] + tensor_ptrs = torch.as_tensor( + [p.data_ptr() for p in tensor_list], device=DEVICE, dtype=torch.uint64 + ) + + code, result = code_and_output( + stack_store_arange_kernel, (tensor_ptrs, tensor_list[0]) + ) + + for i, tensor in enumerate(tensor_list): + assert tensor.eq(i).all().item() + + self.assertExpectedJournal(code) + + +if __name__ == "__main__": + unittest.main()
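For reference, a minimal end-to-end usage sketch of the API introduced above. It mirrors the `stacktensor_like` docstring example and `test_stack_load_grid`, assumes a CUDA device, and the kernel name `stack_load` is illustrative only.

import torch
import helion
import helion.language as hl


@helion.kernel
def stack_load(dev_ptrs: torch.Tensor, example: torch.Tensor) -> torch.Tensor:
    # One device-side load is broadcast across every buffer pointed to by dev_ptrs.
    M = hl.specialize(dev_ptrs.size(0))
    N = example.size(0)
    out = torch.empty(M, N, dtype=example.dtype, device=dev_ptrs.device)
    for tile in hl.tile(N):
        ptr_tile = dev_ptrs[:]  # load the pointer vector, shape [M]
        stacked = hl.stacktensor_like(example, ptr_tile)
        out[:, tile] = stacked[tile]  # stacked load, shape [M, tile]
    return out


tensor_list = [torch.randn(16, device="cuda") for _ in range(4)]
tensor_ptrs = torch.as_tensor(
    [t.data_ptr() for t in tensor_list], dtype=torch.uint64, device="cuda"
)
result = stack_load(tensor_ptrs, tensor_list[0])  # equals torch.stack(tensor_list)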