@@ -5,10 +5,13 @@ Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environmen
5
5
from __future__ import annotations
6
6
7
7
import torch
8
+ import helion
8
9
import triton
9
10
import triton.language as tl
10
11
from helion.runtime import default_launcher as _default_launcher
11
12
13
+ helion.runtime.set_triton_allocator()
14
+
12
15
@triton.jit
13
16
def _stack_load_kernel_2d_kernel(dev_ptrs, out, dev_ptrs_stride_0, dev_ptrs_stride_1, example_tensor_stride_0, out_stride_0, out_stride_1, out_stride_2, N, M2, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
14
17
pid_0 = tl.program_id(0)
@@ -31,10 +34,13 @@ def stack_load_kernel_2d(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *
31
34
return outfrom __future__ import annotations
32
35
33
36
import torch
37
+ import helion
34
38
import triton
35
39
import triton.language as tl
36
40
from helion.runtime import default_launcher as _default_launcher
37
41
42
+ helion.runtime.set_triton_allocator()
43
+
38
44
@triton.jit
39
45
def _stack_load_2d_looped_kernel(dev_ptrs, out, dev_ptrs_stride_0, dev_ptrs_stride_1, example_tensor_stride_0, out_stride_0, out_stride_1, out_stride_2, N, M2, M1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
40
46
pid_0 = tl.program_id(0)
@@ -61,10 +67,13 @@ def stack_load_2d_looped(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *
61
67
from __future__ import annotations
62
68
63
69
import torch
70
+ import helion
64
71
import triton
65
72
import triton.language as tl
66
73
from helion.runtime import default_launcher as _default_launcher
67
74
75
+ helion.runtime.set_triton_allocator()
76
+
68
77
@triton.jit
69
78
def _stack_load_kernel_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, example_tensor_stride_1, out_stride_0, out_stride_1, out_stride_2, N1, N2, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
70
79
num_blocks_0 = tl.cdiv(N1, _BLOCK_SIZE_0)
@@ -96,10 +105,13 @@ def stack_load_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _
96
105
from __future__ import annotations
97
106
98
107
import torch
108
+ import helion
99
109
import triton
100
110
import triton.language as tl
101
111
from helion.runtime import default_launcher as _default_launcher
102
112
113
+ helion.runtime.set_triton_allocator()
114
+
103
115
@triton.jit
104
116
def _stack_load_kernel_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, out_stride_0, out_stride_1, _RDIM_SIZE_1: tl.constexpr):
105
117
pid_0 = tl.program_id(0)
@@ -121,10 +133,13 @@ def stack_load_kernel(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _
121
133
from __future__ import annotations
122
134
123
135
import torch
136
+ import helion
124
137
import triton
125
138
import triton.language as tl
126
139
from helion.runtime import default_launcher as _default_launcher
127
140
141
+ helion.runtime.set_triton_allocator()
142
+
128
143
@triton.jit
129
144
def _stack_load_w_mask_kernel(dev_ptrs, out, dev_ptrs_stride_0, example_tensor_stride_0, out_stride_0, out_stride_1, N, M, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
130
145
pid_0 = tl.program_id(0)
@@ -154,10 +169,13 @@ def stack_load_w_mask(dev_ptrs: torch.Tensor, example_tensor: torch.Tensor, *, _
154
169
from __future__ import annotations
155
170
156
171
import torch
172
+ import helion
157
173
import triton
158
174
import triton.language as tl
159
175
from helion.runtime import default_launcher as _default_launcher
160
176
177
+ helion.runtime.set_triton_allocator()
178
+
161
179
@triton.jit
162
180
def _stack_store_kernel_kernel(dev_ptrs, x, dev_ptrs_stride_0, example_tensor_stride_0, x_stride_0, N, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
163
181
pid_0 = tl.program_id(0)
@@ -181,10 +199,13 @@ def stack_store_kernel(x: torch.Tensor, dev_ptrs: torch.Tensor, example_tensor:
181
199
from __future__ import annotations
182
200
183
201
import torch
202
+ import helion
184
203
import triton
185
204
import triton.language as tl
186
205
from helion.runtime import default_launcher as _default_launcher
187
206
207
+ helion.runtime.set_triton_allocator()
208
+
188
209
@triton.jit
189
210
def _stack_store_kernel_kernel(dev_ptrs, x, dev_ptrs_stride_0, example_tensor_stride_0, x_stride_0, _RDIM_SIZE_1: tl.constexpr):
190
211
pid_0 = tl.program_id(0)
@@ -203,10 +224,13 @@ def stack_store_kernel(x: torch.Tensor, dev_ptrs: torch.Tensor, example_tensor:
203
224
from __future__ import annotations
204
225
205
226
import torch
227
+ import helion
206
228
import triton
207
229
import triton.language as tl
208
230
from helion.runtime import default_launcher as _default_launcher
209
231
232
+ helion.runtime.set_triton_allocator()
233
+
210
234
@triton.jit
211
235
def _stack_store_arange_kernel_kernel(dev_ptrs, dev_ptrs_stride_0, example_tensor_stride_0, _RDIM_SIZE_1: tl.constexpr):
212
236
pid_0 = tl.program_id(0)
0 commit comments