Skip to content

Commit fa5e632

Browse files
joyddddoulgen
andauthored
Improve Stacktensor Doc (#479)
Co-authored-by: Oguz Ulgen <[email protected]>
1 parent 13f8d31 commit fa5e632

File tree

3 files changed

+33
-12
lines changed

3 files changed

+33
-12
lines changed

docs/api/language.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,18 @@ The `Tile` class represents a portion of an iteration space with the following k
134134
.. autofunction:: subscript
135135
```
136136

137+
## StackTensor
138+
### StackTensor class
139+
```{eval-rst}
140+
.. autoclass:: StackTensor
141+
:undoc-members:
142+
```
143+
144+
### stacktensor_like
145+
```{eval-rst}
146+
.. autofunction:: stacktensor_like
147+
```
148+
137149
## Reduction Operations
138150

139151
### reduce()

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .scan_ops import cumsum as cumsum
2121
from .signal_wait import signal as signal
2222
from .signal_wait import wait as wait
23+
from .stack_tensor import StackTensor as StackTensor
2324
from .stack_tensor import stacktensor_like as stacktensor_like
2425
from .tile_ops import tile_begin as tile_begin
2526
from .tile_ops import tile_block_size as tile_block_size

helion/language/stack_tensor.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,22 @@
1717

1818
class StackTensor(NamedTuple):
1919
"""
20-
StackTensor is a batch of tensors of the same properties (shape, dtype and stride)
20+
This class should not be instantiated directly. It is the result of hl.stacktensor_like(...).
21+
It presents a batch of tensors of the same properties (shape, dtype and stride)
2122
but reside at different memory locations virtually stacked together.
22-
It provides a way to perform parallel memory accesses to multiple tensors with a single subscription.
23+
24+
StackTensor provides a way to perform parallel memory accesses to multiple tensors with a single subscription.
25+
2326
2427
**Core Concept:**
28+
2529
Instead of performing separate memory operations on each tensor individually,
2630
StackTensor allows you to broadcast a single memory operation (hl.load, hl.store, hl.atomic_add,
2731
hl.signal, hl.wait etc.) to multiple tensor buffers in parallel. This is particularly useful
2832
for batch processing scenarios where the same operation needs to be applied to multiple tensors.
2933
3034
**Memory Operation Behavior:**
35+
3136
- **Loads**: When you index into a StackTensor (e.g., `stack_tensor[i]`),
3237
it performs the same indexing operation on all underlying tensor buffers and
3338
returns a new tensor where the results are stacked according to the shape of dev_ptrs.
@@ -37,24 +42,27 @@ class StackTensor(NamedTuple):
3742
(e.g. value[j] is writtent to tensor_j[i]).
3843
3944
**Shape Semantics:**
45+
4046
The StackTensor's shape is `dev_ptrs.shape + tensor_like.shape`, where:
41-
- `dev_ptrs.shape` represents the "batch" dimensions (how tensors are being stacked)
47+
48+
- `dev_ptrs.shape` becomes the stacking dimensions
4249
- `tensor_like.shape` represents the shape of each individual tensor
4350
51+
"""
4452

45-
Attributes:
46-
tensor_like: A template host tensor that defines the shape, dtype, and other properties
53+
tensor_like: torch.Tensor
54+
"""
55+
A template host tensor that defines the shape, dtype, and other properties
4756
for all tensors in the stack group.
48-
dev_ptrs: A tensor containing device pointers (memory buffer addresses) to the actual
49-
tensors in device memory. Must be of dtype torch.uint64.
50-
51-
Properties:
52-
dtype: The data type of the tensors in the stack group. Inherited from tensor_like.
53-
shape: The shape of the stacked tensor. Computed as dev_ptrs.shape + tensor_like.shape.
57+
Must be a Host tensor (created outside of the device loop).
5458
"""
5559

56-
tensor_like: torch.Tensor
5760
dev_ptrs: torch.Tensor
61+
"""
62+
A tensor containing device pointers (memory buffer addresses) to the actual
63+
tensors in device memory.
64+
Must be of dtype torch.uint64.
65+
"""
5866

5967
@property
6068
def dtype(self) -> torch.dtype:

0 commit comments

Comments
 (0)