
 class StackTensor(NamedTuple):
     """
-    StackTensor is a batch of tensors of the same properties (shape, dtype and stride)
+    This class should not be instantiated directly. It is the result of hl.stacktensor_like(...).
+    It presents a batch of tensors of the same properties (shape, dtype and stride)
     but reside at different memory locations virtually stacked together.
-    It provides a way to perform parallel memory accesses to multiple tensors with a single subscription.
+
+    StackTensor provides a way to perform parallel memory accesses to multiple tensors with a single subscription.
+

     **Core Concept:**
+
     Instead of performing separate memory operations on each tensor individually,
     StackTensor allows you to broadcast a single memory operation (hl.load, hl.store, hl.atomic_add,
     hl.signal, hl.wait etc.) to multiple tensor buffers in parallel. This is particularly useful
     for batch processing scenarios where the same operation needs to be applied to multiple tensors.

     **Memory Operation Behavior:**
+
     - **Loads**: When you index into a StackTensor (e.g., `stack_tensor[i]`),
       it performs the same indexing operation on all underlying tensor buffers and
       returns a new tensor where the results are stacked according to the shape of dev_ptrs.
@@ -37,24 +42,27 @@ class StackTensor(NamedTuple):
       (e.g. value[j] is written to tensor_j[i]).

     **Shape Semantics:**
+
     The StackTensor's shape is `dev_ptrs.shape + tensor_like.shape`, where:
-    - `dev_ptrs.shape` represents the "batch" dimensions (how tensors are being stacked)
+
+    - `dev_ptrs.shape` becomes the stacking dimensions
     - `tensor_like.shape` represents the shape of each individual tensor

+    """

-    Attributes:
-        tensor_like: A template host tensor that defines the shape, dtype, and other properties
+    tensor_like: torch.Tensor
+    """
+    A template host tensor that defines the shape, dtype, and other properties
             for all tensors in the stack group.
-        dev_ptrs: A tensor containing device pointers (memory buffer addresses) to the actual
-            tensors in device memory. Must be of dtype torch.uint64.
-
-    Properties:
-        dtype: The data type of the tensors in the stack group. Inherited from tensor_like.
-        shape: The shape of the stacked tensor. Computed as dev_ptrs.shape + tensor_like.shape.
+    Must be a Host tensor (created outside of the device loop).
     """

-    tensor_like: torch.Tensor
     dev_ptrs: torch.Tensor
+    """
+    A tensor containing device pointers (memory buffer addresses) to the actual
+    tensors in device memory.
+    Must be of dtype torch.uint64.
+    """

     @property
     def dtype(self) -> torch.dtype:
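The load, store, and shape semantics described in the docstring can be illustrated with a small plain-Python emulation. This is only a sketch under stated assumptions: it does not use Helion or PyTorch, the helper names `stacked_shape`, `stacked_load`, and `stacked_store` are hypothetical, ordinary lists stand in for device buffers, and only a one-dimensional stack is modeled.

```python
# Plain-Python emulation of the semantics described in the docstring.
# Nothing here is Helion API: the helper names are hypothetical, and
# ordinary lists stand in for device buffers and for dev_ptrs.


def stacked_shape(dev_ptrs_shape, tensor_like_shape):
    """Shape rule from the docstring: dev_ptrs.shape + tensor_like.shape."""
    return tuple(dev_ptrs_shape) + tuple(tensor_like_shape)


def stacked_load(buffers, i):
    """Apply the same index to every buffer; result j comes from buffer j."""
    return [buf[i] for buf in buffers]


def stacked_store(buffers, i, values):
    """Write values[j] to buffers[j][i] for every j."""
    if len(values) != len(buffers):
        raise ValueError("need exactly one value per stacked buffer")
    for buf, value in zip(buffers, values):
        buf[i] = value


# Three separate "tensors" of shape (4,), stacked along one dimension.
buffers = [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]]

shape = stacked_shape((len(buffers),), (4,))   # stacking dims come first
loaded = stacked_load(buffers, 2)              # one element from each buffer
stacked_store(buffers, 0, [7, 8, 9])           # one value written per buffer
```

In the real class the per-buffer loop would presumably be replaced by a single broadcast memory operation over the addresses held in `dev_ptrs`; the lists above only mirror the indexing and shape rules, not the parallel execution.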