
Commit 975959b

Allow string literal args (#353)
1 parent 78e4663 commit 975959b

File tree

3 files changed: +123 -0 lines changed

helion/_compiler/compile_environment.py

Lines changed: 2 additions & 0 deletions
@@ -233,6 +233,8 @@ def to_fake(self, obj: object, origin: Origin) -> object:
             return lift_closures(fn, origin)
         if isinstance(obj, ConstExpr):
             return obj.value
+        if isinstance(obj, str):
+            return obj
         if isinstance(obj, list):
             return [self.to_fake(e, origin) for e in obj]
         if isinstance(obj, tuple) and hasattr(obj, "_fields"):
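
With this pass-through, a plain Python string passed to a kernel reaches tracing unchanged and behaves as a compile-time constant, so it no longer has to be wrapped in hl.constexpr. A minimal usage sketch, not part of the commit: the function name and the "copy" value are illustrative, it assumes helion and helion.language are importable as shown and that a supported device string replaces "cuda", and the kernel body mirrors the new test below.

import torch
import helion
import helion.language as hl

@helion.kernel()
def add_or_copy(x: torch.Tensor, mode: str) -> torch.Tensor:
    # `mode` is an ordinary string argument; its value is baked in when the kernel compiles.
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        if mode == "add":
            out[tile] = x[tile] + 1.0
        else:
            out[tile] = x[tile]
    return out

x = torch.randn([512, 512], device="cuda")
y = add_or_copy(x, "add")   # the "add" branch is selected at compile time
z = add_or_copy(x, "copy")  # a different string value produces its own compiled variant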

test/test_constexpr.expected

Lines changed: 91 additions & 0 deletions
@@ -97,3 +97,94 @@ def fn(x: torch.Tensor, s: hl.constexpr, *, _launcher=_default_launcher):
     _BLOCK_SIZE_1 = 16
     _launcher(_fn_kernel, (triton.cdiv(b, _BLOCK_SIZE_0) * triton.cdiv(16, _BLOCK_SIZE_1),), x, out, out.stride(0), out.stride(1), x.stride(0), b, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
+
+--- assertExpectedJournal(TestConstExpr.test_string_literal_arg)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 1.0
+    v_1 = load + v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, mode: str, *, _launcher=_default_launcher):
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _launcher(_fn_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestConstExpr.test_string_literal_arg)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 2.0
+    v_1 = load * v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, mode: str, *, _launcher=_default_launcher):
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _launcher(_fn_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestConstExpr.test_string_literal_arg)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), load, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, mode: str, *, _launcher=_default_launcher):
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _launcher(_fn_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out

test/test_constexpr.py

Lines changed: 30 additions & 0 deletions
@@ -61,6 +61,36 @@ def fn(x: torch.Tensor, s: hl.constexpr) -> torch.Tensor:
         torch.testing.assert_close(result, x.view(-1, 1).expand(512, 16))
         self.assertExpectedJournal(code)
 
+    def test_string_literal_arg(self):
+        @helion.kernel()
+        def fn(x: torch.Tensor, mode: str) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(x.size()):
+                if mode == "add":
+                    out[tile] = x[tile] + 1.0
+                elif mode == "mul":
+                    out[tile] = x[tile] * 2.0
+                else:
+                    out[tile] = x[tile]
+            return out
+
+        x = torch.randn([512, 512], device=DEVICE)
+
+        # Test "add" mode
+        code, result = code_and_output(fn, (x, "add"))
+        torch.testing.assert_close(result, x + 1.0)
+        self.assertExpectedJournal(code)
+
+        # Test "mul" mode
+        code, result = code_and_output(fn, (x, "mul"))
+        torch.testing.assert_close(result, x * 2.0)
+        self.assertExpectedJournal(code)
+
+        # Test default mode
+        code, result = code_and_output(fn, (x, "default"))
+        torch.testing.assert_close(result, x)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
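
Judging by the three expected journals above, the string argument is resolved during compilation: "add", "mul", and the fallback value each generate their own Triton code, and the mode comparison does not appear inside the generated kernel. A small illustrative check in the style of the test, reusing its fn, x, and code_and_output; the inequality assertion is an added assumption, not something the test itself asserts.

# Each distinct string value produces its own specialization of the kernel.
code_add, out_add = code_and_output(fn, (x, "add"))
code_mul, out_mul = code_and_output(fn, (x, "mul"))
torch.testing.assert_close(out_add, x + 1.0)
torch.testing.assert_close(out_mul, x * 2.0)
assert code_add != code_mul  # e.g. "load + v_0" vs. "load * v_0" in the generated kernels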
