@@ -58,6 +58,145 @@ def _arange_three_args_step_make_precompiler(x: torch.Tensor):
58
58
from helion.runtime.precompile_shim import make_precompiler
59
59
return make_precompiler(_arange_three_args_step_kernel)(out, out.size(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
60
60
61
+ --- assertExpectedJournal(TestIndexing.test_broadcasting_block_ptr_indexing)
62
+ from __future__ import annotations
63
+
64
+ import torch
65
+ import triton
66
+ import triton.language as tl
67
+
68
@triton.jit
def _broadcast_add_3d_kernel(
    x, bias1, bias2, out,
    bias1_size_1, bias1_size_2,
    bias2_size_0, bias2_size_2,
    out_size_0, out_size_1, out_size_2,
    x_size_0, x_size_1, x_size_2,
    bias1_stride_0, bias1_stride_1, bias1_stride_2,
    bias2_stride_0, bias2_stride_1, bias2_stride_2,
    out_stride_0, out_stride_1, out_stride_2,
    x_stride_0, x_stride_1, x_stride_2,
    d0, d1,
    _BLOCK_SIZE_0: tl.constexpr,
    _BLOCK_SIZE_1: tl.constexpr,
    _BLOCK_SIZE_2: tl.constexpr,
):
    """Write one tile of ``out = (x + bias1) + bias2`` using block pointers.

    Each program handles a single
    ``(_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2)`` tile. ``bias1`` is
    addressed with extent 1 on axis 0 and ``bias2`` with extent 1 on axis 1,
    so their tiles broadcast against the ``x`` tile along those axes.
    Loads pad out-of-bounds elements with zero; the store is bounds-checked
    on all three axes.
    """
    # Split the flat 1-D program id into per-axis tile coordinates.
    # Axis 0 varies fastest; axis 2 is whatever remains after dividing out
    # the tile counts of axes 0 and 1 (hence no d2 argument is needed).
    pid = tl.program_id(0)
    tiles_0 = tl.cdiv(d0, _BLOCK_SIZE_0)
    tiles_1 = tl.cdiv(d1, _BLOCK_SIZE_1)
    start_0 = (pid % tiles_0) * _BLOCK_SIZE_0
    start_1 = (pid // tiles_0 % tiles_1) * _BLOCK_SIZE_1
    start_2 = (pid // (tiles_0 * tiles_1)) * _BLOCK_SIZE_2

    x_block = tl.make_block_ptr(
        base=x,
        shape=[x_size_0, x_size_1, x_size_2],
        strides=[x_stride_0, x_stride_1, x_stride_2],
        offsets=[start_0, start_1, start_2],
        block_shape=[_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2],
        order=[2, 1, 0],
    )
    x_tile = tl.load(x_block, boundary_check=[0, 1, 2], padding_option='zero')

    # bias1: axis 0 is pinned to extent 1 / offset 0, so only axes 1 and 2
    # need a boundary check.
    bias1_block = tl.make_block_ptr(
        base=bias1,
        shape=[1, bias1_size_1, bias1_size_2],
        strides=[bias1_stride_0, bias1_stride_1, bias1_stride_2],
        offsets=[0, start_1, start_2],
        block_shape=[1, _BLOCK_SIZE_1, _BLOCK_SIZE_2],
        order=[2, 1, 0],
    )
    bias1_tile = tl.load(bias1_block, boundary_check=[1, 2], padding_option='zero')
    partial_sum = x_tile + bias1_tile

    # bias2: axis 1 is pinned to extent 1 / offset 0, so only axes 0 and 2
    # need a boundary check.
    bias2_block = tl.make_block_ptr(
        base=bias2,
        shape=[bias2_size_0, 1, bias2_size_2],
        strides=[bias2_stride_0, bias2_stride_1, bias2_stride_2],
        offsets=[start_0, 0, start_2],
        block_shape=[_BLOCK_SIZE_0, 1, _BLOCK_SIZE_2],
        order=[2, 1, 0],
    )
    bias2_tile = tl.load(bias2_block, boundary_check=[0, 2], padding_option='zero')
    total = partial_sum + bias2_tile

    out_block = tl.make_block_ptr(
        base=out,
        shape=[out_size_0, out_size_1, out_size_2],
        strides=[out_stride_0, out_stride_1, out_stride_2],
        offsets=[start_0, start_1, start_2],
        block_shape=[_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2],
        order=[2, 1, 0],
    )
    tl.store(out_block, total, boundary_check=[0, 1, 2])
84
+
85
def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
    """Launch the block-pointer kernel and return ``x + bias1 + bias2``.

    ``x`` must be 3-D (its size is unpacked into three extents). The result
    is allocated with ``torch.empty_like(x)`` and filled by the kernel.
    Presumably ``bias1`` is shaped (1, d1, d2) and ``bias2`` (d0, 1, d2) --
    confirm against the calling test.
    """
    d0, d1, d2 = x.size()
    out = torch.empty_like(x)
    block_0 = block_1 = block_2 = 8

    # One program per tile, flattened into a 1-D launch grid.
    n_programs = (
        triton.cdiv(d0, block_0)
        * triton.cdiv(d1, block_1)
        * triton.cdiv(d2, block_2)
    )
    grid = (n_programs,)

    launch_args = (
        x, bias1, bias2, out,
        # sizes: only the non-broadcast extents of the biases are passed
        bias1.size(1), bias1.size(2),
        bias2.size(0), bias2.size(2),
        out.size(0), out.size(1), out.size(2),
        x.size(0), x.size(1), x.size(2),
        # strides
        bias1.stride(0), bias1.stride(1), bias1.stride(2),
        bias2.stride(0), bias2.stride(1), bias2.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        x.stride(0), x.stride(1), x.stride(2),
        # extents used to decode the program id, then the tile sizes
        d0, d1,
        block_0, block_1, block_2,
    )
    _broadcast_add_3d_kernel[grid](*launch_args, num_warps=4, num_stages=3)
    return out
93
+
94
def _broadcast_add_3d_make_precompiler(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
    """Pass the block-pointer kernel's launch arguments to helion's precompile shim.

    Allocates an output with ``torch.empty_like(x)``, assembles the kernel
    argument list, and returns whatever
    ``make_precompiler(_broadcast_add_3d_kernel)(...)`` returns. No grid is
    computed here.
    """
    # x must be 3-D; the last extent is not needed by this argument list.
    d0, d1, _ = x.size()
    out = torch.empty_like(x)
    block_0 = block_1 = block_2 = 8
    from helion.runtime.precompile_shim import make_precompiler

    kernel_args = (
        x, bias1, bias2, out,
        bias1.size(1), bias1.size(2),
        bias2.size(0), bias2.size(2),
        out.size(0), out.size(1), out.size(2),
        x.size(0), x.size(1), x.size(2),
        bias1.stride(0), bias1.stride(1), bias1.stride(2),
        bias2.stride(0), bias2.stride(1), bias2.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        x.stride(0), x.stride(1), x.stride(2),
        d0, d1,
        block_0, block_1, block_2,
    )
    precompile = make_precompiler(_broadcast_add_3d_kernel)
    return precompile(*kernel_args, num_warps=4, num_stages=3)
102
+
103
+ --- assertExpectedJournal(TestIndexing.test_broadcasting_pointer_indexing)
104
+ from __future__ import annotations
105
+
106
+ import torch
107
+ import triton
108
+ import triton.language as tl
109
+
110
@triton.jit
def _broadcast_add_3d_kernel(
    x, bias1, bias2, out,
    bias1_stride_1, bias1_stride_2,
    bias2_stride_0, bias2_stride_2,
    out_stride_0, out_stride_1, out_stride_2,
    x_stride_0, x_stride_1, x_stride_2,
    d0, d1, d2,
    _BLOCK_SIZE_0: tl.constexpr,
    _BLOCK_SIZE_1: tl.constexpr,
    _BLOCK_SIZE_2: tl.constexpr,
):
    """Write one tile of ``out = (x + bias1) + bias2`` using raw pointer math.

    Each program handles a single
    ``(_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2)`` tile. ``bias1`` is
    indexed only along axes 1 and 2 and ``bias2`` only along axes 0 and 2,
    so their tiles broadcast against the ``x`` tile along the missing axis.
    Elements outside ``(d0, d1, d2)`` are masked off: loads yield 0 there and
    the store skips them.
    """
    # Split the flat 1-D program id into per-axis tile coordinates
    # (axis 0 varies fastest).
    pid = tl.program_id(0)
    tiles_0 = tl.cdiv(d0, _BLOCK_SIZE_0)
    tiles_1 = tl.cdiv(d1, _BLOCK_SIZE_1)
    start_0 = (pid % tiles_0) * _BLOCK_SIZE_0
    start_1 = (pid // tiles_0 % tiles_1) * _BLOCK_SIZE_1
    start_2 = (pid // (tiles_0 * tiles_1)) * _BLOCK_SIZE_2

    # Per-axis element indices for this tile and their in-bounds masks.
    idx_0 = (start_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    idx_1 = (start_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    idx_2 = (start_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
    valid_0 = idx_0 < d0
    valid_1 = idx_1 < d1
    valid_2 = idx_2 < d2

    # Lift each 1-D vector onto its own axis of the 3-D tile.
    i0 = idx_0[:, None, None]
    i1 = idx_1[None, :, None]
    i2 = idx_2[None, None, :]
    m0 = valid_0[:, None, None]
    m1 = valid_1[None, :, None]
    m2 = valid_2[None, None, :]
    tile_mask = m0 & m1 & m2

    x_tile = tl.load(
        x + (i0 * x_stride_0 + i1 * x_stride_1 + i2 * x_stride_2),
        tile_mask,
        other=0,
    )
    bias1_tile = tl.load(
        bias1 + (i1 * bias1_stride_1 + i2 * bias1_stride_2),
        m1 & m2,
        other=0,
    )
    partial_sum = x_tile + bias1_tile
    bias2_tile = tl.load(
        bias2 + (i0 * bias2_stride_0 + i2 * bias2_stride_2),
        m0 & m2,
        other=0,
    )
    total = partial_sum + bias2_tile
    tl.store(
        out + (i0 * out_stride_0 + i1 * out_stride_1 + i2 * out_stride_2),
        total,
        tile_mask,
    )
132
+
133
def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
    """Launch the pointer-indexing kernel and return ``x + bias1 + bias2``.

    ``x`` must be 3-D (its size is unpacked into three extents). The result
    is allocated with ``torch.empty_like(x)`` and filled by the kernel.
    Only the strides of the non-broadcast axes of each bias are forwarded.
    Presumably ``bias1`` is shaped (1, d1, d2) and ``bias2`` (d0, 1, d2) --
    confirm against the calling test.
    """
    d0, d1, d2 = x.size()
    out = torch.empty_like(x)
    block_0 = block_1 = block_2 = 8

    # One program per tile, flattened into a 1-D launch grid.
    n_programs = (
        triton.cdiv(d0, block_0)
        * triton.cdiv(d1, block_1)
        * triton.cdiv(d2, block_2)
    )
    grid = (n_programs,)

    launch_args = (
        x, bias1, bias2, out,
        bias1.stride(1), bias1.stride(2),
        bias2.stride(0), bias2.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        x.stride(0), x.stride(1), x.stride(2),
        d0, d1, d2,
        block_0, block_1, block_2,
    )
    _broadcast_add_3d_kernel[grid](*launch_args, num_warps=4, num_stages=3)
    return out
141
+
142
def _broadcast_add_3d_make_precompiler(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
    """Pass the pointer-indexing kernel's launch arguments to helion's precompile shim.

    Allocates an output with ``torch.empty_like(x)``, assembles the kernel
    argument list, and returns whatever
    ``make_precompiler(_broadcast_add_3d_kernel)(...)`` returns. No grid is
    computed here.
    """
    d0, d1, d2 = x.size()
    out = torch.empty_like(x)
    block_0 = block_1 = block_2 = 8
    from helion.runtime.precompile_shim import make_precompiler

    kernel_args = (
        x, bias1, bias2, out,
        bias1.stride(1), bias1.stride(2),
        bias2.stride(0), bias2.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        x.stride(0), x.stride(1), x.stride(2),
        d0, d1, d2,
        block_0, block_1, block_2,
    )
    precompile = make_precompiler(_broadcast_add_3d_kernel)
    return precompile(*kernel_args, num_warps=4, num_stages=3)
150
+
151
+ --- assertExpectedJournal(TestIndexing.test_broadcasting_tensor_descriptor_indexing)
152
+ from __future__ import annotations
153
+
154
+ import torch
155
+ import helion
156
+ import triton
157
+ import triton.language as tl
158
+
159
# NOTE(review): runs at import time, before the tensor-descriptor kernel is defined; presumably installs the allocator that tl.make_tensor_descriptor relies on -- confirm against helion.runtime.
+ helion.runtime.set_triton_allocator()
160
+
161
@triton.jit
def _broadcast_add_3d_kernel(
    x, bias1, bias2, out,
    bias1_size_1, bias1_size_2,
    bias2_size_0, bias2_size_2,
    out_size_0, out_size_1, out_size_2,
    x_size_0, x_size_1, x_size_2,
    bias1_stride_0, bias1_stride_1, bias1_stride_2,
    bias2_stride_0, bias2_stride_1, bias2_stride_2,
    out_stride_0, out_stride_1, out_stride_2,
    x_stride_0, x_stride_1, x_stride_2,
    d0, d1,
    _BLOCK_SIZE_0: tl.constexpr,
    _BLOCK_SIZE_1: tl.constexpr,
    _BLOCK_SIZE_2: tl.constexpr,
):
    """Write one tile of ``out = (x + bias1) + bias2`` using tensor descriptors.

    Each program handles a single
    ``(_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2)`` tile. ``bias1`` is
    described with extent 1 on axis 0 and ``bias2`` with extent 1 on axis 1,
    so their tiles broadcast against the ``x`` tile along those axes.
    NOTE(review): no explicit bounds handling appears here; out-of-range
    behaviour is left to the descriptor load/store -- confirm.
    """
    x_view = tl.make_tensor_descriptor(
        x,
        [x_size_0, x_size_1, x_size_2],
        [x_stride_0, x_stride_1, x_stride_2],
        [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2],
    )
    bias1_view = tl.make_tensor_descriptor(
        bias1,
        [1, bias1_size_1, bias1_size_2],
        [bias1_stride_0, bias1_stride_1, bias1_stride_2],
        [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2],
    )
    bias2_view = tl.make_tensor_descriptor(
        bias2,
        [bias2_size_0, 1, bias2_size_2],
        [bias2_stride_0, bias2_stride_1, bias2_stride_2],
        [_BLOCK_SIZE_0, 1, _BLOCK_SIZE_2],
    )
    out_view = tl.make_tensor_descriptor(
        out,
        [out_size_0, out_size_1, out_size_2],
        [out_stride_0, out_stride_1, out_stride_2],
        [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2],
    )

    # Split the flat 1-D program id into per-axis tile coordinates.
    # Axis 0 varies fastest; axis 2 is whatever remains after dividing out
    # the tile counts of axes 0 and 1 (hence no d2 argument is needed).
    pid = tl.program_id(0)
    tiles_0 = tl.cdiv(d0, _BLOCK_SIZE_0)
    tiles_1 = tl.cdiv(d1, _BLOCK_SIZE_1)
    start_0 = (pid % tiles_0) * _BLOCK_SIZE_0
    start_1 = (pid // tiles_0 % tiles_1) * _BLOCK_SIZE_1
    start_2 = (pid // (tiles_0 * tiles_1)) * _BLOCK_SIZE_2

    x_tile = x_view.load([start_0, start_1, start_2])
    # The broadcast axis of each bias is always read from offset 0.
    bias1_tile = bias1_view.load([0, start_1, start_2])
    partial_sum = x_tile + bias1_tile
    bias2_tile = bias2_view.load([start_0, 0, start_2])
    total = partial_sum + bias2_tile
    out_view.store([start_0, start_1, start_2], total)
181
+
182
def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
    """Launch the tensor-descriptor kernel and return ``x + bias1 + bias2``.

    ``x`` must be 3-D (its size is unpacked into three extents). The result
    is allocated with ``torch.empty_like(x)`` and filled by the kernel.
    Presumably ``bias1`` is shaped (1, d1, d2) and ``bias2`` (d0, 1, d2) --
    confirm against the calling test.
    """
    d0, d1, d2 = x.size()
    out = torch.empty_like(x)
    block_0 = block_1 = block_2 = 8

    # One program per tile, flattened into a 1-D launch grid.
    n_programs = (
        triton.cdiv(d0, block_0)
        * triton.cdiv(d1, block_1)
        * triton.cdiv(d2, block_2)
    )
    grid = (n_programs,)

    launch_args = (
        x, bias1, bias2, out,
        # sizes: only the non-broadcast extents of the biases are passed
        bias1.size(1), bias1.size(2),
        bias2.size(0), bias2.size(2),
        out.size(0), out.size(1), out.size(2),
        x.size(0), x.size(1), x.size(2),
        # strides
        bias1.stride(0), bias1.stride(1), bias1.stride(2),
        bias2.stride(0), bias2.stride(1), bias2.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        x.stride(0), x.stride(1), x.stride(2),
        # extents used to decode the program id, then the tile sizes
        d0, d1,
        block_0, block_1, block_2,
    )
    _broadcast_add_3d_kernel[grid](*launch_args, num_warps=4, num_stages=3)
    return out
190
+
191
def _broadcast_add_3d_make_precompiler(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
    """Pass the tensor-descriptor kernel's launch arguments to helion's precompile shim.

    Allocates an output with ``torch.empty_like(x)``, assembles the kernel
    argument list, and returns whatever
    ``make_precompiler(_broadcast_add_3d_kernel)(...)`` returns. No grid is
    computed here.
    """
    # x must be 3-D; the last extent is not needed by this argument list.
    d0, d1, _ = x.size()
    out = torch.empty_like(x)
    block_0 = block_1 = block_2 = 8
    from helion.runtime.precompile_shim import make_precompiler

    kernel_args = (
        x, bias1, bias2, out,
        bias1.size(1), bias1.size(2),
        bias2.size(0), bias2.size(2),
        out.size(0), out.size(1), out.size(2),
        x.size(0), x.size(1), x.size(2),
        bias1.stride(0), bias1.stride(1), bias1.stride(2),
        bias2.stride(0), bias2.stride(1), bias2.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        x.stride(0), x.stride(1), x.stride(2),
        d0, d1,
        block_0, block_1, block_2,
    )
    precompile = make_precompiler(_broadcast_add_3d_kernel)
    return precompile(*kernel_args, num_warps=4, num_stages=3)
199
+
61
200
--- assertExpectedJournal(TestIndexing.test_mask_load)
62
201
from __future__ import annotations
63
202
0 commit comments