Commit 0bcfcca

Add multicast tensor
stack-info: PR: #346, branch: joydddd/stack/17
1 parent: 14110be

9 files changed: +708 −19 lines
helion/_compiler/device_ir.py

Lines changed: 9 additions & 2 deletions
@@ -52,6 +52,7 @@
 from .type_propagation import GridIndexType
 from .type_propagation import IterType
 from .type_propagation import LiteralType
+from .type_propagation import MulticastTensorType
 from .type_propagation import NumericType
 from .type_propagation import SequenceType
 from .type_propagation import TensorType
@@ -781,7 +782,9 @@ def visit_Assign(self, node: ast.Assign) -> None:
         assert isinstance(target.value, ExtendedAST)
         assert target.value._type_info is not None
         target_origin = target.value._type_info.origin  # pyright: ignore[reportOptionalMemberAccess]
-        if not target_origin.is_host():
+        if not target_origin.is_host() and not isinstance(
+            target.value._type_info, MulticastTensorType
+        ):
             # Get the variable name for the error message
             var_name = (
                 target.value.id
@@ -806,7 +809,9 @@ def _assign_subscript(self, target: ast.Subscript, val: object) -> None:
         assert isinstance(target.value, ExtendedAST)
         assert target.value._type_info is not None
         target_origin = target.value._type_info.origin
-        assert target_origin.is_host()
+        assert target_origin.is_host() or isinstance(
+            target.value._type_info, MulticastTensorType
+        )

         return hl.store(
             self.visit(target.value),  # pyright: ignore[reportArgumentType]
@@ -839,6 +844,8 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
             if isinstance(node.slice, ast.Constant):
                 return self.visit(value)[self.visit(node.slice)]  # pyright: ignore[reportIndexIssue]
             raise exc.InvalidSequenceSubscription(node.slice)
+        if type_info is not None and isinstance(type_info, MulticastTensorType):
+            return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
         if type_info is not None and type_info.origin.is_host():
             return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
         return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
helion/_compiler/indexing_strategy.py

Lines changed: 134 additions & 0 deletions
@@ -8,6 +8,7 @@

 import sympy
 import torch
+from torch._inductor.utils import triton_type

 from .. import exc
 from .ast_extension import expr_from_string
@@ -18,10 +19,15 @@
 from .variable_origin import BlockSizeOrigin

 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     from ..runtime.config import Config
     from .device_function import TensorDescriptorArg
     from .inductor_lowering import CodegenState

+    SymIntLike = torch.SymInt | int
+    ShapeLike = Sequence[SymIntLike]
+

 class IndexingStrategy:
     def codegen_load(
@@ -275,6 +281,134 @@ def codegen_store(
         )


+class MulticastIndexingStrategy:
+    @staticmethod
+    def get_broadcast_str(
+        multicast_shape: ShapeLike,
+        subscript_shape: ShapeLike,
+    ) -> tuple[str, str]:
+        multicast_broadcast_keys = [":" for _ in multicast_shape] + [
+            "None" for _ in subscript_shape
+        ]
+        multicast_broadcast = f"[{', '.join(multicast_broadcast_keys)}]"
+        tensor_broadcast_keys = ["None" for _ in multicast_shape] + [
+            ":" for _ in subscript_shape
+        ]
+        tensor_broadcast = f"[{', '.join(tensor_broadcast_keys)}]"
+
+        return multicast_broadcast, tensor_broadcast
+
+    @staticmethod
+    def get_mask_expr(
+        state: CodegenState,
+        indexing: SubscriptIndexing,
+        multicast_shape: ShapeLike,
+        subscript_shape: ShapeLike,
+    ) -> ast.AST | None:
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscript_shape
+            )
+        )
+
+        mask_exprs = []
+        dev_ptr_mask_exprs = []
+        # Generate Mask
+
+        for dim, size in enumerate(multicast_shape):
+            if (
+                index := CompileEnvironment.current().get_block_id(size)
+            ) is not None and (mask_var := state.codegen.mask_var(index)) is not None:
+                expand = state.tile_strategy.expand_str(multicast_shape, dim)
+                dev_ptr_mask_exprs.append(f"({mask_var}{expand})")
+
+        if dev_ptr_mask_exprs:
+            dev_ptr_mask_expr = f"({'&'.join(dev_ptr_mask_exprs)})"
+            if len(dev_ptr_mask_exprs) < len(multicast_shape):
+                dev_ptr_mask_expr = f"tl.broadcast_to({dev_ptr_mask_expr}, {state.tile_strategy.shape_str(multicast_shape)})"
+            dev_ptr_mask_expr = f"({dev_ptr_mask_expr}){multicast_broadcast}"
+            mask_exprs.append(dev_ptr_mask_expr)
+
+        if indexing.has_mask():
+            mask_exprs.append(f"(tensor_mask){tensor_broadcast}")
+            return expr_from_string(
+                "&".join(mask_exprs), tensor_mask=indexing.mask_expr
+            )
+        if mask_exprs:
+            return expr_from_string("&".join(mask_exprs))
+        return None
+
+    @staticmethod
+    def codegen_load(
+        state: CodegenState,
+        tensors: tuple[torch.Tensor, torch.Tensor],
+        dev_ptrs_ast: ast.AST,
+        subscript: list[object],
+        extra_mask: ast.AST | None,
+    ) -> ast.AST:
+        tensor_like, dev_ptrs = tensors
+        indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
+        subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
+        multicast_shape = [*dev_ptrs.size()]
+
+        mask_expr = MulticastIndexingStrategy.get_mask_expr(
+            state, indexing, multicast_shape, subscripts_shape
+        )
+        extra = ", other=0"
+        if mask_expr is None:
+            mask_expr = expr_from_string("None")
+            extra = ""
+
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscripts_shape
+            )
+        )
+
+        dtype = triton_type(tensor_like.dtype)
+        return expr_from_string(
+            f"tl.load((base.to(tl.pointer_type({dtype}))){multicast_broadcast} + (offset){tensor_broadcast}, mask{extra})",
+            base=dev_ptrs_ast,
+            offset=indexing.index_expr,
+            mask=mask_expr,
+        )
+
+    @staticmethod
+    def codegen_store(
+        state: CodegenState,
+        tensors: tuple[torch.Tensor, torch.Tensor],
+        dev_ptrs_ast: ast.AST,
+        subscript: list[object],
+        value: ast.AST,
+        extra_mask: ast.AST | None,
+    ) -> ast.AST:
+        tensor_like, dev_ptrs = tensors
+        indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
+        subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
+        multicast_shape = [*dev_ptrs.size()]
+
+        mask_expr = MulticastIndexingStrategy.get_mask_expr(
+            state, indexing, multicast_shape, subscripts_shape
+        )
+        if mask_expr is None:
+            mask_expr = expr_from_string("None")
+
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscripts_shape
+            )
+        )
+
+        dtype = triton_type(tensor_like.dtype)
+        return expr_from_string(
+            f"tl.store(base.to(tl.pointer_type({dtype})){multicast_broadcast} + (offset){tensor_broadcast}, value, mask)",
+            base=dev_ptrs_ast,
+            value=value,
+            offset=indexing.index_expr,
+            mask=mask_expr,
+        )
+
+
 class SubscriptIndexing(NamedTuple):
     index_expr: ast.AST
     mask_expr: ast.AST
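
For intuition, the two broadcast strings simply place the device-pointer dimensions in front of the tile dimensions. A worked example of get_broadcast_str under assumed shapes (a 1-D pointer array and a 2-D tile):

    # Assumed: multicast_shape = [NUM_DEV], subscript_shape = [BLOCK_M, BLOCK_N]
    multicast_broadcast  # "[:, None, None]"  (dev ptrs broadcast across tile dims)
    tensor_broadcast     # "[None, :, :]"     (tile offsets broadcast across dev ptrs)
    # codegen_load then emits a Triton expression of the form:
    #   tl.load((base.to(tl.pointer_type(tl.float32)))[:, None, None]
    #           + (offset)[None, :, :], mask, other=0)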

helion/_compiler/type_propagation.py

Lines changed: 82 additions & 1 deletion
@@ -27,6 +27,7 @@
 from ..autotuner.config_spec import BlockSizeSpec
 from ..language._decorators import get_device_func_replacement
 from ..language._decorators import is_api_func
+from ..language.multicast_tensor import MulticastTensor
 from ..language.tile_proxy import Tile
 from ..language.tile_proxy import _CheckForIndexCalls
 from .ast_extension import ExtendedAST
@@ -1289,6 +1290,86 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
         return self.element_types[attr]


+class MulticastTensorType(ClassType):
+    element_types: dict[str, TypeInfo]  # pyright: ignore[reportIncompatibleVariableOverride]
+
+    def proxy(self) -> MulticastTensor:  # pyright: ignore[reportIncompatibleMethodOverride]
+        with proxy_tensor.disable_proxy_modes_tracing():
+            fake_mode = torch._C._unset_dispatch_mode(  # pyright: ignore[reportAttributeAccessIssue]
+                torch._C._TorchDispatchModeKey.FAKE  # pyright: ignore[reportAttributeAccessIssue]
+            )
+            try:
+                assert isinstance(self.element_types["tensor_like"], TensorType)
+                assert isinstance(self.element_types["dev_ptrs"], TensorType)
+                return MulticastTensor(
+                    self.element_types["tensor_like"].proxy(),
+                    self.element_types["dev_ptrs"].proxy(),
+                )
+            finally:
+                assert fake_mode is not None
+                torch._C._set_dispatch_mode(fake_mode)  # pyright: ignore[reportAttributeAccessIssue]
+
+    def merge(self, other: TypeInfo) -> TypeInfo:
+        if isinstance(other, MulticastTensorType):
+            self_elements = self.element_types
+            other_elements = other.element_types
+            if set(self_elements.keys()) == set(other_elements.keys()):
+                return MulticastTensorType(
+                    origin=other.origin,
+                    element_types={
+                        key: self_elements[key].merge(other_elements[key])
+                        for key in self_elements
+                    },
+                )
+        return super().merge(other)
+
+    def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
+        tensor_like_type = self.element_types["tensor_like"]
+        assert isinstance(tensor_like_type, TensorType)
+        size_like = tensor_like_type._device_indexing_size(key)
+
+        dev_ptrs_type = self.element_types["dev_ptrs"]
+        assert isinstance(dev_ptrs_type, TensorType)
+        multicast_size = list(dev_ptrs_type.fake_value.size())
+
+        return multicast_size + size_like
+
+    def propagate_setitem(
+        self, key: TypeInfo, value: TypeInfo, origin: Origin
+    ) -> TypeInfo:
+        if origin.is_host():
+            warning(exc.TensorOperationInWrapper)
+        else:
+            lhs_shape = self._device_indexing_size(key)
+            lhs_rank = len(lhs_shape)
+            if isinstance(value, TensorType):
+                rhs_rank = value.fake_value.ndim
+                if lhs_rank != rhs_rank:
+                    raise exc.RankMismatch(
+                        lhs_rank,
+                        rhs_rank,
+                        f"LHS shape: {tuple(lhs_shape)}, RHS shape: {tuple(value.fake_value.shape)}",
+                    )
+            elif isinstance(value, (NumericType, LiteralType)):
+                # Allow scalar assignment to tensor (broadcasts to tensor shape)
+                pass
+            else:
+                raise exc.RequiresTensorInAssignment(value)
+        return self
+
+    def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
+        if origin.is_host():
+            warning(exc.TensorOperationInWrapper)
+
+        assert isinstance(self.element_types["tensor_like"], TensorType)
+        return TensorType(
+            origin,
+            self.element_types["tensor_like"]
+            .proxy()
+            .new_empty(self._device_indexing_size(key)),
+        )
+
+
 class SliceType(CollectionType):
     element_types: slice  # pyright: ignore[reportIncompatibleVariableOverride]

@@ -1614,7 +1695,7 @@ def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None:
         if isinstance(lhs, ast.Subscript):
             # TODO(jansel): test different types of subscript
             lhs_base_type = self.visit(lhs.value)
-            if isinstance(lhs_base_type, TensorType):
+            if isinstance(lhs_base_type, (TensorType, MulticastTensorType)):
                 self.visit(lhs)  # need to populate shape info
                 lhs_base_type = lhs_base_type.propagate_setitem(
                     self.visit(lhs.slice), rhs, self.origin()
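
The shape contract in _device_indexing_size is the heart of the type rule: the dev_ptrs dimensions are prepended to the shape the subscript would produce on tensor_like, so indexing a multicast tensor yields one tile per target device. A small sketch of the arithmetic (shapes assumed for illustration):

    multicast_size = [4]        # dev_ptrs.size(): pointers to 4 peer buffers
    size_like = [64, 32]        # tensor_like indexed by a [64, 32] tile
    multicast_size + size_like  # -> [4, 64, 32], the shape propagate_getitem reports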

helion/exc.py

Lines changed: 18 additions & 0 deletions
@@ -138,6 +138,24 @@ class SpecializeArgType(BaseError):
     message = "hl.specialize() must be called on a size from an input tensor, got: {}"


+class MulticastTensorOnHost(BaseError):
+    message = (
+        "hl.multicast_tensor must be called inside the `hl.tile` or `hl.grid` loop."
+    )
+
+
+class MulticastTensorDevPtrOnHost(BaseError):
+    message = "hl.multicast_tensor must be called with a dev_ptr tensor defined on device. Use `hl.load` to load a dev_ptrs tensor."
+
+
+class MulticastTensorDevPtrDtype(BaseError):
+    message = "hl.multicast_tensor must be called with a dev_ptr tensor of dtype int64. Got: {0!s}"
+
+
+class MulticastTensorExampleOnDevice(BaseError):
+    message = "hl.multicast_tensor must be called with an example host tensor."
+
+
 class FailedToUnpackTupleAssign(BaseError):
     message = "Failed to unpack values in tuple assignment. Expected a sequence of size {0}, got type: {1!s}."

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 from .memory_ops import atomic_add as atomic_add
 from .memory_ops import load as load
 from .memory_ops import store as store
+from .multicast_tensor import multicast_like as multicast_like
 from .reduce_ops import reduce as reduce
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
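
Putting the pieces together, a sketch of how the new hl.multicast_like export might be used. Its definition is not among the hunks shown here, so the call below, the full-slice pointer load, and the kernel itself are assumptions pieced together from the error messages in helion/exc.py (an int64 dev_ptrs tensor loaded on device, an example host tensor, called inside an hl.tile loop), not the commit's own example:

    import torch

    import helion
    import helion.language as hl


    @helion.kernel
    def peer_sum(dev_ptrs: torch.Tensor, example: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        # dev_ptrs: int64 tensor of raw device pointers, one per peer buffer.
        for tile in hl.tile(example.size(0)):
            ptrs = dev_ptrs[:]                     # host subscript lowers to hl.load on device
            mt = hl.multicast_like(example, ptrs)  # hypothetical signature
            vals = mt[tile]                        # shape: [num_peers, tile]
            out[tile] = vals.sum(dim=0)            # reduce across peer buffers
        return out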
