Skip to content

Commit 8ac85e0

Browse files
authored
Fix: test/test_signal_pad (#432)
1 parent 0211a67 commit 8ac85e0

File tree

1 file changed

+52
-48
lines changed

test/test_signal_wait.expected

Lines changed: 52 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -72,29 +72,6 @@ def gmem_signal_cas_kernel(signal_pad: torch.Tensor, *, _launcher=_default_launc
7272
_launcher(_gmem_signal_cas_kernel_kernel, (n,), signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
7373
return signal_pad
7474

75-
--- assertExpectedJournal(TestWait.test_signal_stack_signalpad)
76-
from __future__ import annotations
77-
78-
import torch
79-
import helion
80-
import triton
81-
import triton.language as tl
82-
from helion.runtime import default_launcher as _default_launcher
83-
84-
@triton.jit
85-
def _gmem_signal_pointers_kernel_kernel(signal_pad_ptrs, signal_pad_ptrs_size_0, example_stride_0, signal_pad_ptrs_stride_0, _RDIM_SIZE_1: tl.constexpr):
86-
pid_0 = tl.program_id(0)
87-
offset_0 = pid_0
88-
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
89-
mask_1 = indices_1 < signal_pad_ptrs_size_0
90-
ptr_tile = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
91-
helion.runtime.triton_send_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_stride_0)[None], update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
92-
93-
def gmem_signal_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tensor, *, _launcher=_default_launcher):
94-
_RDIM_SIZE_1 = triton.next_power_of_2(signal_pad_ptrs.size(0))
95-
_launcher(_gmem_signal_pointers_kernel_kernel, (example.size(0),), signal_pad_ptrs, signal_pad_ptrs.size(0), example.stride(0), signal_pad_ptrs.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3)
96-
return signal_pad_ptrs
97-
9875
--- assertExpectedJournal(TestWait.test_signal_multiple)
9976
from __future__ import annotations
10077

@@ -143,6 +120,31 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor, *, _launcher=_defaul
143120
_launcher(_gmem_signal_tensor_bar_kernel_kernel, (triton.cdiv(n, _BLOCK_SIZE_0),), signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
144121
return signal_pad
145122

123+
--- assertExpectedJournal(TestWait.test_signal_stack_signalpad)
124+
from __future__ import annotations
125+
126+
import torch
127+
import helion
128+
import triton
129+
import triton.language as tl
130+
from helion.runtime import default_launcher as _default_launcher
131+
132+
helion.runtime.set_triton_allocator()
133+
134+
@triton.jit
135+
def _gmem_signal_pointers_kernel_kernel(signal_pad_ptrs, signal_pad_ptrs_size_0, example_stride_0, signal_pad_ptrs_stride_0, _RDIM_SIZE_1: tl.constexpr):
136+
pid_0 = tl.program_id(0)
137+
offset_0 = pid_0
138+
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
139+
mask_1 = indices_1 < signal_pad_ptrs_size_0
140+
ptr_tile = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
141+
helion.runtime.triton_send_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_stride_0)[None], update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
142+
143+
def gmem_signal_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tensor, *, _launcher=_default_launcher):
144+
_RDIM_SIZE_1 = triton.next_power_of_2(signal_pad_ptrs.size(0))
145+
_launcher(_gmem_signal_pointers_kernel_kernel, (example.size(0),), signal_pad_ptrs, signal_pad_ptrs.size(0), example.stride(0), signal_pad_ptrs.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3)
146+
return signal_pad_ptrs
147+
146148
--- assertExpectedJournal(TestWait.test_wait_2d_tile)
147149
from __future__ import annotations
148150

@@ -256,31 +258,6 @@ def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor, *, _launcher=_defau
256258
_launcher(_gmem_wait_multi_bar_kernel_cas_kernel, (triton.cdiv(N, _BLOCK_SIZE_0),), signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
257259
return signal_pad
258260

259-
--- assertExpectedJournal(TestWait.test_wait_stack_signalpad)
260-
from __future__ import annotations
261-
262-
import torch
263-
import helion
264-
import triton
265-
import triton.language as tl
266-
from helion.runtime import default_launcher as _default_launcher
267-
268-
@triton.jit
269-
def _gmem_wait_pointers_kernel_kernel(signal_pad_ptrs, out, signal_pad_ptrs_size_0, example_stride_0, out_stride_0, signal_pad_ptrs_stride_0, _RDIM_SIZE_1: tl.constexpr):
270-
pid_0 = tl.program_id(0)
271-
offset_0 = pid_0
272-
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
273-
mask_1 = indices_1 < signal_pad_ptrs_size_0
274-
dev_tile = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
275-
helion.runtime.triton_wait_multiple_signal(addr=dev_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_stride_0)[None], expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
276-
tl.store(out + offset_0 * out_stride_0, offset_0, None)
277-
278-
def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tensor, *, _launcher=_default_launcher):
279-
out = torch.empty_like(example)
280-
_RDIM_SIZE_1 = triton.next_power_of_2(signal_pad_ptrs.size(0))
281-
_launcher(_gmem_wait_pointers_kernel_kernel, (example.size(0),), signal_pad_ptrs, out, signal_pad_ptrs.size(0), example.stride(0), out.stride(0), signal_pad_ptrs.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3)
282-
return out
283-
284261
--- assertExpectedJournal(TestWait.test_wait_pointers)
285262
from __future__ import annotations
286263

@@ -311,3 +288,30 @@ def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, pad_shape: hl.const
311288
_BLOCK_SIZE_1 = N
312289
_launcher(_gmem_wait_pointers_kernel_kernel, (4,), signal_pad_ptrs, out, out.stride(0), signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
313290
return out
291+
292+
--- assertExpectedJournal(TestWait.test_wait_stack_signalpad)
293+
from __future__ import annotations
294+
295+
import torch
296+
import helion
297+
import triton
298+
import triton.language as tl
299+
from helion.runtime import default_launcher as _default_launcher
300+
301+
helion.runtime.set_triton_allocator()
302+
303+
@triton.jit
304+
def _gmem_wait_pointers_kernel_kernel(signal_pad_ptrs, out, signal_pad_ptrs_size_0, example_stride_0, out_stride_0, signal_pad_ptrs_stride_0, _RDIM_SIZE_1: tl.constexpr):
305+
pid_0 = tl.program_id(0)
306+
offset_0 = pid_0
307+
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
308+
mask_1 = indices_1 < signal_pad_ptrs_size_0
309+
dev_tile = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
310+
helion.runtime.triton_wait_multiple_signal(addr=dev_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_stride_0)[None], expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
311+
tl.store(out + offset_0 * out_stride_0, offset_0, None)
312+
313+
def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tensor, *, _launcher=_default_launcher):
314+
out = torch.empty_like(example)
315+
_RDIM_SIZE_1 = triton.next_power_of_2(signal_pad_ptrs.size(0))
316+
_launcher(_gmem_wait_pointers_kernel_kernel, (example.size(0),), signal_pad_ptrs, out, signal_pad_ptrs.size(0), example.stride(0), out.stride(0), signal_pad_ptrs.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3)
317+
return out

0 commit comments

Comments
 (0)