Skip to content

Add stacked tensor #346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .type_propagation import LiteralType
from .type_propagation import NumericType
from .type_propagation import SequenceType
from .type_propagation import StackTensorType
from .type_propagation import TensorType
from .type_propagation import TileIndexType
from .type_propagation import TypeInfo
Expand Down Expand Up @@ -321,12 +322,14 @@ def build_rolled_reductions(self) -> None:
graph_to_info = {}
allow_loop = False

# First, check if any graph contains matmul with rdim
# First, check if any graph contains matmul or dev_prts stacking with rdim
# If so, we can't roll any graphs in this reduction dimension
can_roll_graphs = True
for graph_info in self.graphs:
roller = ReductionRoller(self, rdim, {})
if roller.has_matmul_with_rdim(graph_info.graph):
if roller.has_matmul_with_rdim(
graph_info.graph
) or roller.has_stack_tensor_with_rdim(graph_info.graph):
can_roll_graphs = False
break

Expand Down Expand Up @@ -870,7 +873,9 @@ def visit_Assign(self, node: ast.Assign) -> None:
assert isinstance(target.value, ExtendedAST)
assert target.value._type_info is not None
target_origin = target.value._type_info.origin # pyright: ignore[reportOptionalMemberAccess]
if not target_origin.is_host():
if not target_origin.is_host() and not isinstance(
target.value._type_info, StackTensorType
):
# Get the variable name for the error message
var_name = (
target.value.id
Expand All @@ -895,7 +900,9 @@ def _assign_subscript(self, target: ast.Subscript, val: object) -> None:
assert isinstance(target.value, ExtendedAST)
assert target.value._type_info is not None
target_origin = target.value._type_info.origin
assert target_origin.is_host()
assert target_origin.is_host() or isinstance(
target.value._type_info, StackTensorType
)

return hl.store(
self.visit(target.value), # pyright: ignore[reportArgumentType]
Expand Down Expand Up @@ -928,6 +935,8 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
if isinstance(node.slice, ast.Constant):
return self.visit(value)[self.visit(node.slice)] # pyright: ignore[reportIndexIssue]
raise exc.InvalidSequenceSubscription(node.slice)
if isinstance(type_info, StackTensorType):
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
if type_info is not None and type_info.origin.is_host():
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
Expand Down
147 changes: 147 additions & 0 deletions helion/_compiler/indexing_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import sympy
import torch
from torch._inductor.utils import triton_type

from .. import exc
from .._compat import get_tensor_descriptor_fn_name
Expand All @@ -19,10 +20,15 @@
from .variable_origin import BlockSizeOrigin

if TYPE_CHECKING:
from collections.abc import Sequence

from ..runtime.config import Config
from .device_function import TensorDescriptorArg
from .inductor_lowering import CodegenState

SymIntLike = torch.SymInt | int
ShapeLike = Sequence[SymIntLike]


class IndexingStrategy:
def codegen_load(
Expand Down Expand Up @@ -296,6 +302,147 @@ def codegen_store(
)


class StackIndexingStrategy:
"""
Generate pointer math for stacking load/store to several device memory pointers sharing the same indexing.

offset, mask are calculated for the tensor_like template tensor and then broadcasted to each dev_ptr
, with the results stacked.

e.g. for a 1D offset tensor and a 1D dev_ptr array, the stack offset is:
stack_offset = dev_ptrs[:, None] + offset[None, :]

"""

@staticmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more detail on the semantics of this indexing strategy

def get_broadcast_str(
stack_shape: ShapeLike,
subscript_shape: ShapeLike,
) -> tuple[str, str]:
"""
Args:
stack_shape: shape of the dev_ptr tensor.
subscript_shape: shape of subscription for each individual tensor.

Returns:
the broadcast str for dev_ptrs and individual tensor offset.
"""
stack_broadcast_keys = [":" for _ in stack_shape] + [
"None" for _ in subscript_shape
]
stack_broadcast = f"[{', '.join(stack_broadcast_keys)}]"
tensor_broadcast_keys = ["None" for _ in stack_shape] + [
":" for _ in subscript_shape
]
tensor_broadcast = f"[{', '.join(tensor_broadcast_keys)}]"

return stack_broadcast, tensor_broadcast

@staticmethod
def get_mask_expr(
state: CodegenState,
indexing: SubscriptIndexing,
stack_shape: ShapeLike,
subscript_shape: ShapeLike,
) -> ast.AST | None:
stack_broadcast, tensor_broadcast = StackIndexingStrategy.get_broadcast_str(
stack_shape, subscript_shape
)

mask_exprs = []
dev_ptr_mask_exprs = []
# Generate Mask

for dim, size in enumerate(stack_shape):
if (
index := CompileEnvironment.current().get_block_id(size)
) is not None and (mask_var := state.codegen.mask_var(index)) is not None:
expand = state.tile_strategy.expand_str(stack_shape, dim)
dev_ptr_mask_exprs.append(f"({mask_var}{expand})")

if dev_ptr_mask_exprs:
dev_ptr_mask_expr = f"({'&'.join(dev_ptr_mask_exprs)})"
if len(dev_ptr_mask_exprs) < len(stack_shape):
dev_ptr_mask_expr = f"tl.broadcast_to({dev_ptr_mask_expr}, {state.tile_strategy.shape_str(stack_shape)})"
dev_ptr_mask_expr = f"({dev_ptr_mask_expr}){stack_broadcast}"
mask_exprs.append(dev_ptr_mask_expr)

if indexing.has_mask():
mask_exprs.append(f"(tensor_mask){tensor_broadcast}")
return expr_from_string(
"&".join(mask_exprs), tensor_mask=indexing.mask_expr
)
if mask_exprs:
return expr_from_string("&".join(mask_exprs))
return None

@staticmethod
def codegen_load(
state: CodegenState,
stack_tensor: tuple[torch.Tensor, torch.Tensor],
dev_ptrs_ast: ast.AST,
subscript: list[object],
extra_mask: ast.AST | None,
) -> ast.AST:
tensor_like, dev_ptrs = stack_tensor
indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
stack_shape = [*dev_ptrs.size()]

mask_expr = StackIndexingStrategy.get_mask_expr(
state, indexing, stack_shape, subscripts_shape
)
extra = ", other=0"
if mask_expr is None:
mask_expr = expr_from_string("None")
extra = ""

stack_broadcast, tensor_broadcast = StackIndexingStrategy.get_broadcast_str(
stack_shape, subscripts_shape
)

dtype = triton_type(tensor_like.dtype)
return expr_from_string(
f"tl.load((base.to(tl.pointer_type({dtype}))){stack_broadcast} + (offset){tensor_broadcast}, mask{extra})",
base=dev_ptrs_ast,
offset=indexing.index_expr,
mask=mask_expr,
)

@staticmethod
def codegen_store(
state: CodegenState,
stack_tensor: tuple[torch.Tensor, torch.Tensor],
dev_ptrs_ast: ast.AST,
subscript: list[object],
value: ast.AST,
extra_mask: ast.AST | None,
) -> ast.AST:
tensor_like, dev_ptrs = stack_tensor
indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
stack_shape = [*dev_ptrs.size()]

mask_expr = StackIndexingStrategy.get_mask_expr(
state, indexing, stack_shape, subscripts_shape
)
if mask_expr is None:
mask_expr = expr_from_string("None")

stack_broadcast, tensor_broadcast = StackIndexingStrategy.get_broadcast_str(
stack_shape, subscripts_shape
)

dtype = triton_type(tensor_like.dtype)
return expr_from_string(
f"tl.store(base.to(tl.pointer_type({dtype})){stack_broadcast} + (offset){tensor_broadcast}, value, mask)",
base=dev_ptrs_ast,
value=value,
offset=indexing.index_expr,
mask=mask_expr,
)


class SubscriptIndexing(NamedTuple):
index_expr: ast.AST
mask_expr: ast.AST
Expand Down
30 changes: 30 additions & 0 deletions helion/_compiler/roll_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from torch.fx import map_arg

from ..language import _MEMORY_OPS
from ..language._tracing_ops import _for_loop
from ..language._tracing_ops import _get_symnode
from ..language._tracing_ops import _host_tensor
Expand Down Expand Up @@ -277,6 +278,35 @@ def is_matmul_with_rdim(node: torch.fx.Node) -> bool:

return any(is_matmul_with_rdim(node) for node in graph.nodes)

def has_stack_tensor_with_rdim(self, graph: torch.fx.Graph) -> bool:
"""Check if a graph contains stack tensors with rdim inputs."""

def is_stack_with_rdim(node: torch.fx.Node) -> bool:
"""Check if a node is a stack dev_ptr with rdim inputs."""
if node.op != "call_function":
return False

if node.target not in _MEMORY_OPS:
return False

host_tensor = node.args[0]

if not isinstance(host_tensor, tuple):
return False

# Check if stack dims have rdim
if len(host_tensor) == 2:
assert isinstance(host_tensor[1], torch.fx.Node)
stack = host_tensor[1].meta.get("val", None)
if isinstance(stack, torch.Tensor):
for size in stack.size():
block_idx = CompileEnvironment.current().get_block_id(size)
if block_idx == self.rdim.block_id:
return True
return False

return any(is_stack_with_rdim(node) for node in graph.nodes)

def process(self, graph: torch.fx.Graph) -> torch.fx.Graph:
for node in graph.nodes:
if self.should_go_in_inner_graph(node):
Expand Down
83 changes: 82 additions & 1 deletion helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ..autotuner.config_spec import BlockSizeSpec
from ..language._decorators import get_device_func_replacement
from ..language._decorators import is_api_func
from ..language.stack_tensor import StackTensor
from ..language.tile_proxy import Tile
from ..language.tile_proxy import _CheckForIndexCalls
from .ast_extension import ExtendedAST
Expand Down Expand Up @@ -1294,6 +1295,86 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
return self.element_types[attr]


class StackTensorType(ClassType):
element_types: dict[str, TypeInfo] # pyright: ignore[reportIncompatibleVariableOverride]

def proxy(self) -> StackTensor: # pyright: ignore[reportIncompatibleMethodOverride]
with proxy_tensor.disable_proxy_modes_tracing():
fake_mode = torch._C._unset_dispatch_mode( # pyright: ignore[reportAttributeAccessIssue]
torch._C._TorchDispatchModeKey.FAKE # pyright: ignore[reportAttributeAccessIssue]
)
try:
assert isinstance(self.element_types["tensor_like"], TensorType)
assert isinstance(self.element_types["dev_ptrs"], TensorType)
return StackTensor(
self.element_types["tensor_like"].proxy(),
self.element_types["dev_ptrs"].proxy(),
)
finally:
assert fake_mode is not None
torch._C._set_dispatch_mode(fake_mode) # pyright: ignore[reportAttributeAccessIssue]

def merge(self, other: TypeInfo) -> TypeInfo:
if isinstance(other, StackTensorType):
self_elements = self.element_types
other_elements = other.element_types
if set(self_elements.keys()) == set(other_elements.keys()):
return StackTensorType(
origin=other.origin,
element_types={
key: self_elements[key].merge(other_elements[key])
for key in self_elements
},
)
return super().merge(other)

def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
tensor_like_type = self.element_types["tensor_like"]
assert isinstance(tensor_like_type, TensorType)
size_like = tensor_like_type._device_indexing_size(key)

dev_ptrs_type = self.element_types["dev_ptrs"]
assert isinstance(dev_ptrs_type, TensorType)
stack_size = list(dev_ptrs_type.fake_value.size())

return stack_size + size_like

def propagate_setitem(
self, key: TypeInfo, value: TypeInfo, origin: Origin
) -> TypeInfo:
if origin.is_host():
warning(exc.TensorOperationInWrapper)
else:
lhs_shape = self._device_indexing_size(key)
lhs_rank = len(lhs_shape)
if isinstance(value, TensorType):
rhs_rank = value.fake_value.ndim
if lhs_rank != rhs_rank:
raise exc.RankMismatch(
lhs_rank,
rhs_rank,
f"LHS shape: {tuple(lhs_shape)}, RHS shape: {tuple(value.fake_value.shape)}",
)
elif isinstance(value, (NumericType, LiteralType)):
# Allow scalar assignment to tensor (broadcasts to tensor shape)
pass
else:
raise exc.RequiresTensorInAssignment(value)
return self

def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
if origin.is_host():
warning(exc.TensorOperationInWrapper)

assert isinstance(self.element_types["tensor_like"], TensorType)
return TensorType(
origin,
self.element_types["tensor_like"]
.proxy()
.new_empty(self._device_indexing_size(key)),
)


class SliceType(CollectionType):
element_types: slice # pyright: ignore[reportIncompatibleVariableOverride]

Expand Down Expand Up @@ -1619,7 +1700,7 @@ def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None:
if isinstance(lhs, ast.Subscript):
# TODO(jansel): test different types of subscript
lhs_base_type = self.visit(lhs.value)
if isinstance(lhs_base_type, TensorType):
if isinstance(lhs_base_type, (TensorType, StackTensorType)):
self.visit(lhs) # need to populate shape info
lhs_base_type = lhs_base_type.propagate_setitem(
self.visit(lhs.slice), rhs, self.origin()
Expand Down
Loading
Loading