@@ -97,3 +97,94 @@ def fn(x: torch.Tensor, s: hl.constexpr, *, _launcher=_default_launcher):
     _BLOCK_SIZE_1 = 16
     _launcher(_fn_kernel, (triton.cdiv(b, _BLOCK_SIZE_0) * triton.cdiv(16, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), b, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
+
+--- assertExpectedJournal(TestConstExpr.test_string_literal_arg)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 1.0
+    v_1 = load + v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, mode: str, *, _launcher=_default_launcher):
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _launcher(_fn_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestConstExpr.test_string_literal_arg)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 2.0
+    v_1 = load * v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, mode: str, *, _launcher=_default_launcher):
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _launcher(_fn_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestConstExpr.test_string_literal_arg)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), load, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, mode: str, *, _launcher=_default_launcher):
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _launcher(_fn_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
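
Note on why this diff records three variants of the same journal entry: Helion specializes a kernel once per distinct string literal passed for `mode`, so the journal captures one generated Triton kernel per mode value (add 1.0, multiply by 2.0, and a plain copy). Below is a minimal sketch of the kind of Helion kernel that could produce these entries; the actual test source is not part of this diff, so the function name, the mode strings, and the tiling style are assumptions.

    import torch
    import helion
    import helion.language as hl

    # Hypothetical kernel under test (not taken from this diff). Helion treats
    # the `mode` string argument as a compile-time constant, so the Python
    # branch below is resolved at compile time and each mode value yields its
    # own specialized Triton kernel, matching the three journal entries above.
    @helion.kernel()
    def fn(x: torch.Tensor, mode: str) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(x.size()):
            if mode == "add":      # first entry: v_1 = load + 1.0
                out[tile] = x[tile] + 1.0
            elif mode == "mul":    # second entry: v_1 = load * 2.0
                out[tile] = x[tile] * 2.0
            else:                  # third entry: plain copy
                out[tile] = x[tile]
        return out

Because the branch is folded away during compilation, none of the generated kernels contains a runtime string comparison; only the surviving arithmetic appears in the emitted Triton code.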