@@ -72,29 +72,6 @@ def gmem_signal_cas_kernel(signal_pad: torch.Tensor, *, _launcher=_default_launc
    _launcher(_gmem_signal_cas_kernel_kernel, (n,), signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
    return signal_pad

- --- assertExpectedJournal(TestWait.test_signal_stack_signalpad)
- from __future__ import annotations
-
- import torch
- import helion
- import triton
- import triton.language as tl
- from helion.runtime import default_launcher as _default_launcher
-
- @triton.jit
- def _gmem_signal_pointers_kernel_kernel(signal_pad_ptrs, signal_pad_ptrs_size_0, example_stride_0, signal_pad_ptrs_stride_0, _RDIM_SIZE_1: tl.constexpr):
-     pid_0 = tl.program_id(0)
-     offset_0 = pid_0
-     indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
-     mask_1 = indices_1 < signal_pad_ptrs_size_0
-     ptr_tile = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
-     helion.runtime.triton_send_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_stride_0)[None], update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
-
- def gmem_signal_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tensor, *, _launcher=_default_launcher):
-     _RDIM_SIZE_1 = triton.next_power_of_2(signal_pad_ptrs.size(0))
-     _launcher(_gmem_signal_pointers_kernel_kernel, (example.size(0),), signal_pad_ptrs, signal_pad_ptrs.size(0), example.stride(0), signal_pad_ptrs.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3)
-     return signal_pad_ptrs
-

--- assertExpectedJournal(TestWait.test_signal_multiple)
from __future__ import annotations
@@ -143,6 +120,31 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor, *, _launcher=_defaul
    _launcher(_gmem_signal_tensor_bar_kernel_kernel, (triton.cdiv(n, _BLOCK_SIZE_0),), signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
    return signal_pad

+ --- assertExpectedJournal(TestWait.test_signal_stack_signalpad)
+ from __future__ import annotations
+
+ import torch
+ import helion
+ import triton
+ import triton.language as tl
+ from helion.runtime import default_launcher as _default_launcher
+
+ helion.runtime.set_triton_allocator()
+
+ @triton.jit
+ def _gmem_signal_pointers_kernel_kernel(signal_pad_ptrs, signal_pad_ptrs_size_0, example_stride_0, signal_pad_ptrs_stride_0, _RDIM_SIZE_1: tl.constexpr):
+     pid_0 = tl.program_id(0)
+     offset_0 = pid_0
+     indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+     mask_1 = indices_1 < signal_pad_ptrs_size_0
+     ptr_tile = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
+     helion.runtime.triton_send_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_stride_0)[None], update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
+
+ def gmem_signal_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tensor, *, _launcher=_default_launcher):
+     _RDIM_SIZE_1 = triton.next_power_of_2(signal_pad_ptrs.size(0))
+     _launcher(_gmem_signal_pointers_kernel_kernel, (example.size(0),), signal_pad_ptrs, signal_pad_ptrs.size(0), example.stride(0), signal_pad_ptrs.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3)
+     return signal_pad_ptrs
+

--- assertExpectedJournal(TestWait.test_wait_2d_tile)
from __future__ import annotations
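
A hedged host-side sketch of how the pointer-table kernel added above might be driven. It is not part of the journal; it assumes CUDA, int32 signal pads with at least example.numel() elements each, and every name other than gmem_signal_pointers_kernel is illustrative.

import torch

num_pads, n = 4, 8
pads = [torch.zeros(n, dtype=torch.int32, device="cuda") for _ in range(num_pads)]
# The generated kernel loads raw device addresses and reinterprets them as
# int32 pointers, so the table holds each pad's data_ptr() as a 64-bit int.
signal_pad_ptrs = torch.tensor([p.data_ptr() for p in pads],
                               dtype=torch.int64, device="cuda")
example = torch.empty(n, dtype=torch.int32, device="cuda")  # sets grid size and stride

gmem_signal_pointers_kernel(signal_pad_ptrs, example)
# Program pid atomic_xchg's 1 (release, gpu scope) into element pid of every
# pad, so after the launch each pad is all ones.
assert all(bool((p == 1).all()) for p in pads)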
@@ -256,31 +258,6 @@ def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor, *, _launcher=_defau
    _launcher(_gmem_wait_multi_bar_kernel_cas_kernel, (triton.cdiv(N, _BLOCK_SIZE_0),), signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
    return signal_pad

- --- assertExpectedJournal(TestWait.test_wait_stack_signalpad)
- from __future__ import annotations
-
- import torch
- import helion
- import triton
- import triton.language as tl
- from helion.runtime import default_launcher as _default_launcher
-
- @triton.jit
- def _gmem_wait_pointers_kernel_kernel(signal_pad_ptrs, out, signal_pad_ptrs_size_0, example_stride_0, out_stride_0, signal_pad_ptrs_stride_0, _RDIM_SIZE_1: tl.constexpr):
-     pid_0 = tl.program_id(0)
-     offset_0 = pid_0
-     indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
-     mask_1 = indices_1 < signal_pad_ptrs_size_0
-     dev_tile = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
-     helion.runtime.triton_wait_multiple_signal(addr=dev_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_stride_0)[None], expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
-     tl.store(out + offset_0 * out_stride_0, offset_0, None)
-
- def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tensor, *, _launcher=_default_launcher):
-     out = torch.empty_like(example)
-     _RDIM_SIZE_1 = triton.next_power_of_2(signal_pad_ptrs.size(0))
-     _launcher(_gmem_wait_pointers_kernel_kernel, (example.size(0),), signal_pad_ptrs, out, signal_pad_ptrs.size(0), example.stride(0), out.stride(0), signal_pad_ptrs.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3)
-     return out
-

--- assertExpectedJournal(TestWait.test_wait_pointers)
from __future__ import annotations
@@ -311,3 +288,30 @@ def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, pad_shape: hl.const
    _BLOCK_SIZE_1 = N
    _launcher(_gmem_wait_pointers_kernel_kernel, (4,), signal_pad_ptrs, out, out.stride(0), signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out
+
+ --- assertExpectedJournal(TestWait.test_wait_stack_signalpad)
+ from __future__ import annotations
+
+ import torch
+ import helion
+ import triton
+ import triton.language as tl
+ from helion.runtime import default_launcher as _default_launcher
+
+ helion.runtime.set_triton_allocator()
+
+ @triton.jit
+ def _gmem_wait_pointers_kernel_kernel(signal_pad_ptrs, out, signal_pad_ptrs_size_0, example_stride_0, out_stride_0, signal_pad_ptrs_stride_0, _RDIM_SIZE_1: tl.constexpr):
+     pid_0 = tl.program_id(0)
+     offset_0 = pid_0
+     indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+     mask_1 = indices_1 < signal_pad_ptrs_size_0
+     dev_tile = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
+     helion.runtime.triton_wait_multiple_signal(addr=dev_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_stride_0)[None], expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
+     tl.store(out + offset_0 * out_stride_0, offset_0, None)
+
+ def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tensor, *, _launcher=_default_launcher):
+     out = torch.empty_like(example)
+     _RDIM_SIZE_1 = triton.next_power_of_2(signal_pad_ptrs.size(0))
+     _launcher(_gmem_wait_pointers_kernel_kernel, (example.size(0),), signal_pad_ptrs, out, signal_pad_ptrs.size(0), example.stride(0), out.stride(0), signal_pad_ptrs.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3)
+     return out
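
A matching hedged sketch for the wait-side entry added above, under the same assumptions and with the same illustrative names; the pads are pre-signaled so the ld/acquire spin in the generated kernel completes immediately.

import torch

num_pads, n = 4, 8
pads = [torch.ones(n, dtype=torch.int32, device="cuda") for _ in range(num_pads)]
signal_pad_ptrs = torch.tensor([p.data_ptr() for p in pads],
                               dtype=torch.int64, device="cuda")
example = torch.empty(n, dtype=torch.int32, device="cuda")

# Program pid spins until element pid of every pad reads 1 (acquire, gpu
# scope), then records its own index, so out comes back as [0, 1, ..., n-1].
out = gmem_wait_pointers_kernel(signal_pad_ptrs, example)
assert torch.equal(out, torch.arange(n, dtype=torch.int32, device="cuda"))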