Skip to content

Commit eab7179

Browse files
authored
Fix test/test_stack_tensor.py (#431)
1 parent 8ac85e0 commit eab7179

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

test/test_stack_tensor.expected

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environmen
55
from __future__ import annotations
66

77
import torch
8+
import helion
89
import triton
910
import triton.language as tl
1011
from helion.runtime import default_launcher as _default_launcher
1112

13+
helion.runtime.set_triton_allocator()
14+
1215
@triton.jit
1316
def _stack_load_kernel_2d_kernel(dev_ptrs, out, dev_ptrs_stride_0, dev_ptrs_stride_1, example_tensor_stride_0, out_stride_0, out_stride_1, out_stride_2, N, M2, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
1417
pid_0 = tl.program_id(0)
@@ -31,10 +34,13 @@ def stack_load_kernel_2d(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *
3134
return outfrom __future__ import annotations
3235

3336
import torch
37+
import helion
3438
import triton
3539
import triton.language as tl
3640
from helion.runtime import default_launcher as _default_launcher
3741

42+
helion.runtime.set_triton_allocator()
43+
3844
@triton.jit
3945
def _stack_load_2d_looped_kernel(dev_ptrs, out, dev_ptrs_stride_0, dev_ptrs_stride_1, example_tensor_stride_0, out_stride_0, out_stride_1, out_stride_2, N, M2, M1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
4046
pid_0 = tl.program_id(0)
@@ -61,10 +67,13 @@ def stack_load_2d_looped(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *
6167
from __future__ import annotations
6268

6369
import torch
70+
import helion
6471
import triton
6572
import triton.language as tl
6673
from helion.runtime import default_launcher as _default_launcher
6774

75+
helion.runtime.set_triton_allocator()
76+
6877
@triton.jit
6978
def _stack_load_kernel_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, example_tensor_stride_1, out_stride_0, out_stride_1, out_stride_2, N1, N2, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
7079
num_blocks_0 = tl.cdiv(N1, _BLOCK_SIZE_0)
@@ -96,10 +105,13 @@ def stack_load_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _
96105
from __future__ import annotations
97106

98107
import torch
108+
import helion
99109
import triton
100110
import triton.language as tl
101111
from helion.runtime import default_launcher as _default_launcher
102112

113+
helion.runtime.set_triton_allocator()
114+
103115
@triton.jit
104116
def _stack_load_kernel_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, out_stride_0, out_stride_1, _RDIM_SIZE_1: tl.constexpr):
105117
pid_0 = tl.program_id(0)
@@ -121,10 +133,13 @@ def stack_load_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _
121133
from __future__ import annotations
122134

123135
import torch
136+
import helion
124137
import triton
125138
import triton.language as tl
126139
from helion.runtime import default_launcher as _default_launcher
127140

141+
helion.runtime.set_triton_allocator()
142+
128143
@triton.jit
129144
def _stack_load_w_mask_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, out_stride_0, out_stride_1, N, M, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
130145
pid_0 = tl.program_id(0)
@@ -154,10 +169,13 @@ def stack_load_w_mask(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _
154169
from __future__ import annotations
155170

156171
import torch
172+
import helion
157173
import triton
158174
import triton.language as tl
159175
from helion.runtime import default_launcher as _default_launcher
160176

177+
helion.runtime.set_triton_allocator()
178+
161179
@triton.jit
162180
def _stack_store_kernel_kernel(dev_ptrs, x, dev_ptrs_stride_0, example_tensor_stride_0, x_stride_0, N, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
163181
pid_0 = tl.program_id(0)
@@ -181,10 +199,13 @@ def stack_store_kernel(x: torch.Tensor, dev_ptrs: torch.Tensor, example_tensor:
181199
from __future__ import annotations
182200

183201
import torch
202+
import helion
184203
import triton
185204
import triton.language as tl
186205
from helion.runtime import default_launcher as _default_launcher
187206

207+
helion.runtime.set_triton_allocator()
208+
188209
@triton.jit
189210
def _stack_store_kernel_kernel(dev_ptrs, x, dev_ptrs_stride_0, example_tensor_stride_0, x_stride_0, _RDIM_SIZE_1: tl.constexpr):
190211
pid_0 = tl.program_id(0)
@@ -203,10 +224,13 @@ def stack_store_kernel(x: torch.Tensor, dev_ptrs: torch.Tensor, example_tensor:
203224
from __future__ import annotations
204225

205226
import torch
227+
import helion
206228
import triton
207229
import triton.language as tl
208230
from helion.runtime import default_launcher as _default_launcher
209231

232+
helion.runtime.set_triton_allocator()
233+
210234
@triton.jit
211235
def _stack_store_arange_kernel_kernel(dev_ptrs, dev_ptrs_stride_0, example_tensor_stride_0, _RDIM_SIZE_1: tl.constexpr):
212236
pid_0 = tl.program_id(0)

0 commit comments

Comments
 (0)