Commit 4808787

Add fallbacks for unary ops that don't support fp16 (#361)
1 parent ca78cd2 commit 4808787

File tree: 4 files changed (+245, -8 lines)

helion/_compiler/inductor_lowering_extra.py

Lines changed: 42 additions & 0 deletions
@@ -7,12 +7,54 @@
 from typing import Generator

 import torch
+from torch._inductor.ir import TensorBox
+from torch._inductor.lowering import lowerings as original_lowerings
 from torch._inductor.lowering import to_dtype
 from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND

 inductor_lowering_dispatch: dict[Callable[..., Any] | str, Callable[..., Any]] = {}


+def create_fp16_to_fp32_unary_fallback_lowering(
+    original_op: Callable[..., object],
+) -> Callable[..., object]:
+    """Create a lowering that converts fp16/bfloat16 inputs to fp32 before calling the operation."""
+
+    @functools.wraps(original_op)
+    def fp32_fallback_lowering(x: object) -> object:
+        if isinstance(x, TensorBox) and (original_dtype := x.get_dtype()) in (
+            torch.float16,
+            torch.bfloat16,
+        ):
+            x_fp32 = to_dtype(x, torch.float32)
+            result_fp32 = original_op(x_fp32)
+            assert isinstance(result_fp32, TensorBox)
+            return to_dtype(result_fp32, original_dtype)
+        return original_op(x)
+
+    return fp32_fallback_lowering
+
+
+# Operations that need fp32 fallbacks due to libdevice/tl_math limitations
+FP32_FALLBACK_OPS_UNARY = [
+    torch.ops.aten.rsqrt.default,
+    torch.ops.aten.sqrt.default,
+    torch.ops.aten.sin.default,
+    torch.ops.aten.cos.default,
+    torch.ops.aten.log.default,
+    torch.ops.aten.tanh.default,
+    torch.ops.aten.log1p.default,
+    torch.ops.aten.expm1.default,
+    torch.ops.aten.exp.default,
+]
+
+# Register fp32 fallback lowerings for ops that don't support fp16/bfloat16
+for op in FP32_FALLBACK_OPS_UNARY:
+    inductor_lowering_dispatch[op] = create_fp16_to_fp32_unary_fallback_lowering(
+        original_lowerings[op]
+    )
+
+
 @contextlib.contextmanager
 def patch_inductor_lowerings() -> Generator[None, Any, Any]:
     """Context manager to temporarily patch the inductor lowering table.

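The new lowering is essentially a compile-time version of a simple eager-mode wrapper: upcast a half-precision input to fp32, run the op, then cast the result back to the original dtype. A standalone sketch of that pattern in plain PyTorch (illustrative only; fp32_unary_fallback is a made-up helper, not part of this repo):

import functools

import torch


def fp32_unary_fallback(op):
    """Wrap a unary tensor op so fp16/bf16 inputs are computed in fp32."""

    @functools.wraps(op)
    def wrapper(x: torch.Tensor) -> torch.Tensor:
        if x.dtype in (torch.float16, torch.bfloat16):
            # Upcast, compute, downcast -- the same shape as the inductor lowering above.
            return op(x.to(torch.float32)).to(x.dtype)
        return op(x)

    return wrapper


rsqrt_half = fp32_unary_fallback(torch.rsqrt)
out = rsqrt_half(torch.rand(8, dtype=torch.float16) + 0.1)
assert out.dtype == torch.float16
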
test/test_reductions.expected

Lines changed: 90 additions & 0 deletions
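The expected journals below check that each op now compiles to a load, cast-to-fp32, libdevice/tl_math call, cast-back-to-fp16, store sequence. For orientation, the same pattern written by hand in Triton looks roughly like this (a sketch only; the kernel and pointer names are made up, not taken from the generated code):

import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice


@triton.jit
def _rsqrt_fp16_sketch(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask, other=0)           # fp16 values come in
    y = libdevice.rsqrt(x.to(tl.float32))              # math routine runs in fp32
    tl.store(out_ptr + offs, y.to(tl.float16), mask)   # cast back to fp16 on store
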
@@ -62,6 +62,96 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o
     _launcher(_reduce_kernel_kernel, (n,), x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
     return out

+--- assertExpectedJournal(TestReductions.test_fp16_math_ops_fp32_fallback)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _rsqrt_fp16_kernel_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_0 = load.to(tl.float32)
+    v_1 = libdevice.rsqrt(v_0)
+    v_2 = v_1.to(tl.float16)
+    tl.store(result + indices_0 * result_stride_0, v_2, mask_0)
+
+def rsqrt_fp16_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    result = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _launcher(_rsqrt_fp16_kernel_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return result
+
+--- assertExpectedJournal(TestReductions.test_fp16_math_ops_fp32_fallback)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _multi_math_ops_fp16_kernel_kernel(x, result, x_size_0, result_stride_0, result_stride_1, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_0 = load.to(tl.float32)
+    v_1 = libdevice.rsqrt(v_0)
+    v_2 = v_1.to(tl.float16)
+    tl.store(result + (indices_0 * result_stride_0 + 0 * result_stride_1), v_2, mask_0)
+    load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_3 = load_1.to(tl.float32)
+    v_4 = libdevice.sqrt(v_3)
+    v_5 = v_4.to(tl.float16)
+    tl.store(result + (indices_0 * result_stride_0 + 1 * result_stride_1), v_5, mask_0)
+    load_2 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_6 = load_2.to(tl.float32)
+    v_7 = tl_math.sin(v_6)
+    v_8 = v_7.to(tl.float16)
+    tl.store(result + (indices_0 * result_stride_0 + 2 * result_stride_1), v_8, mask_0)
+    load_3 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_9 = load_3.to(tl.float32)
+    v_10 = tl_math.cos(v_9)
+    v_11 = v_10.to(tl.float16)
+    tl.store(result + (indices_0 * result_stride_0 + 3 * result_stride_1), v_11, mask_0)
+    load_4 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_12 = load_4.to(tl.float32)
+    v_13 = tl_math.log(v_12)
+    v_14 = v_13.to(tl.float16)
+    tl.store(result + (indices_0 * result_stride_0 + 4 * result_stride_1), v_14, mask_0)
+    load_5 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_15 = load_5.to(tl.float32)
+    v_16 = libdevice.tanh(v_15)
+    v_17 = v_16.to(tl.float16)
+    tl.store(result + (indices_0 * result_stride_0 + 5 * result_stride_1), v_17, mask_0)
+    load_6 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_18 = load_6.to(tl.float32)
+    v_19 = libdevice.log1p(v_18)
+    v_20 = v_19.to(tl.float16)
+    tl.store(result + (indices_0 * result_stride_0 + 6 * result_stride_1), v_20, mask_0)
+    load_7 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_21 = load_7.to(tl.float32)
+    v_22 = tl_math.exp(v_21)
+    v_23 = v_22.to(tl.float16)
+    tl.store(result + (indices_0 * result_stride_0 + 7 * result_stride_1), v_23, mask_0)
+
+def multi_math_ops_fp16_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    result = torch.empty([x.size(0), 8], dtype=x.dtype, device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _launcher(_multi_math_ops_fp16_kernel_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), result.stride(1), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return result
+
 --- assertExpectedJournal(TestReductions.test_fp16_var_mean)
 from __future__ import annotations

test/test_reductions.py

Lines changed: 101 additions & 0 deletions
@@ -230,6 +230,107 @@ def layer_norm_fwd_repro(
         self.assertExpectedJournal(code2)
         torch.testing.assert_close(result1, result2, rtol=1e-3, atol=1e-3)

+    def test_fp16_math_ops_fp32_fallback(self):
+        """Test that mathematical ops with fp16/bfloat16 inputs now work via fp32 fallback."""
+
+        @helion.kernel(use_default_config=True)
+        def rsqrt_fp16_kernel(x: torch.Tensor) -> torch.Tensor:
+            result = torch.empty_like(x)
+            for tile in hl.tile(x.size(0)):
+                # This should now work via fp32 fallback
+                result[tile] = torch.rsqrt(x[tile])
+            return result
+
+        @helion.kernel(use_default_config=True)
+        def multi_math_ops_fp16_kernel(x: torch.Tensor) -> torch.Tensor:
+            result = torch.empty([x.size(0), 8], dtype=x.dtype, device=x.device)
+            for tile in hl.tile(x.size(0)):
+                # Test multiple operations that have confirmed fallbacks
+                result[tile, 0] = torch.rsqrt(x[tile])
+                result[tile, 1] = torch.sqrt(x[tile])
+                result[tile, 2] = torch.sin(x[tile])
+                result[tile, 3] = torch.cos(x[tile])
+                result[tile, 4] = torch.log(x[tile])
+                result[tile, 5] = torch.tanh(x[tile])
+                result[tile, 6] = torch.log1p(x[tile])
+                result[tile, 7] = torch.exp(x[tile])
+            return result
+
+        # Test with float16 - should now succeed
+        x_fp16 = (
+            torch.abs(torch.randn([32], device=DEVICE, dtype=torch.float16)) + 0.1
+        )  # positive values for rsqrt
+
+        code, result = code_and_output(rsqrt_fp16_kernel, (x_fp16,))
+        self.assertExpectedJournal(code)
+
+        # Verify result is correct compared to PyTorch's rsqrt
+        expected = torch.rsqrt(x_fp16)
+        torch.testing.assert_close(result, expected, rtol=1e-3, atol=1e-3)
+
+        # Verify result maintains fp16 dtype
+        self.assertEqual(result.dtype, torch.float16)
+
+        # Test multiple math operations
+        x_multi = torch.abs(torch.randn([16], device=DEVICE, dtype=torch.float16)) + 0.1
+        code_multi, result_multi = code_and_output(
+            multi_math_ops_fp16_kernel, (x_multi,)
+        )
+        self.assertExpectedJournal(code_multi)
+
+        # Verify each operation's correctness
+        expected_rsqrt = torch.rsqrt(x_multi)
+        expected_sqrt = torch.sqrt(x_multi)
+        expected_sin = torch.sin(x_multi)
+        expected_cos = torch.cos(x_multi)
+        expected_log = torch.log(x_multi)
+        expected_tanh = torch.tanh(x_multi)
+        expected_log1p = torch.log1p(x_multi)
+        expected_exp = torch.exp(x_multi)
+
+        torch.testing.assert_close(
+            result_multi[:, 0], expected_rsqrt, rtol=1e-3, atol=1e-3
+        )
+        torch.testing.assert_close(
+            result_multi[:, 1], expected_sqrt, rtol=1e-3, atol=1e-3
+        )
+        torch.testing.assert_close(
+            result_multi[:, 2], expected_sin, rtol=1e-3, atol=1e-3
+        )
+        torch.testing.assert_close(
+            result_multi[:, 3], expected_cos, rtol=1e-3, atol=1e-3
+        )
+        torch.testing.assert_close(
+            result_multi[:, 4], expected_log, rtol=1e-3, atol=1e-3
+        )
+        torch.testing.assert_close(
+            result_multi[:, 5], expected_tanh, rtol=1e-3, atol=1e-3
+        )
+        torch.testing.assert_close(
+            result_multi[:, 6], expected_log1p, rtol=1e-3, atol=1e-3
+        )
+        torch.testing.assert_close(
+            result_multi[:, 7], expected_exp, rtol=1e-3, atol=1e-3
+        )
+
+        # Verify all results maintain fp16 dtype
+        self.assertEqual(result_multi.dtype, torch.float16)
+
+        # Test with bfloat16 if available
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+            x_bf16 = (
+                torch.abs(torch.randn([32], device=DEVICE, dtype=torch.bfloat16)) + 0.1
+            )
+
+            code_bf16, result_bf16 = code_and_output(rsqrt_fp16_kernel, (x_bf16,))
+
+            # Verify bfloat16 result is correct
+            expected_bf16 = torch.rsqrt(x_bf16)
+            torch.testing.assert_close(result_bf16, expected_bf16, rtol=1e-2, atol=1e-2)
+
+            # Verify result maintains bfloat16 dtype
+            self.assertEqual(result_bf16.dtype, torch.bfloat16)
+

 if __name__ == "__main__":
     unittest.main()
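Outside the test harness, such a kernel can be invoked directly on an fp16 tensor; a minimal sketch assuming a CUDA device and the usual import helion.language as hl convention (the kernel here mirrors rsqrt_fp16_kernel from the test above and is not part of the diff):

import torch

import helion
import helion.language as hl


@helion.kernel(use_default_config=True)
def rsqrt_fp16(x: torch.Tensor) -> torch.Tensor:
    result = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        # fp16 input now routes through the fp32 fallback lowering
        result[tile] = torch.rsqrt(x[tile])
    return result


x = torch.rand([1024], device="cuda", dtype=torch.float16) + 0.1  # keep inputs positive for rsqrt
y = rsqrt_fp16(x)
assert y.dtype == torch.float16
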

test/test_views.expected

Lines changed: 12 additions & 8 deletions
@@ -22,12 +22,14 @@ def _softmax_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1,
     amax = tl.max(_mask_to, 1)
     amax_1 = amax[:, None]
     v_0 = values - amax_1
-    v_1 = tl_math.exp(v_0)
-    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_1, 0)
+    v_1 = v_0.to(tl.float32)
+    v_2 = tl_math.exp(v_1)
+    v_3 = v_2.to(tl.float16)
+    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_3, 0)
     sum_1 = tl.sum(_mask_to_1, 1)
     sum_exp = sum_1[None, :]
-    v_2 = v_1 / sum_exp
-    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_2, mask_1[None, :])
+    v_4 = v_3 / sum_exp
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_4, mask_1[None, :])

 def softmax(x: torch.Tensor, *, _launcher=_default_launcher):
     n, _m = x.size()

@@ -57,12 +59,14 @@ def _softmax_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1,
     amax = tl.max(_mask_to, 1)
     amax_1 = tl.reshape(amax, [1, 1])
     v_0 = values - amax_1
-    v_1 = tl_math.exp(v_0)
-    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_1, 0)
+    v_1 = v_0.to(tl.float32)
+    v_2 = tl_math.exp(v_1)
+    v_3 = v_2.to(tl.float16)
+    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_3, 0)
     sum_1 = tl.sum(_mask_to_1, 1)
     sum_exp = tl.reshape(sum_1, [1, 1])
-    v_2 = v_1 / sum_exp
-    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_2, mask_1[None, :])
+    v_4 = v_3 / sum_exp
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_4, mask_1[None, :])

 def softmax(x: torch.Tensor, *, _launcher=_default_launcher):
     n, _m = x.size()
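These softmax journals are generated for a float16 input; they change only because exp now goes through the new fp32 fallback, which inserts the to(tl.float32)/to(tl.float16) pair around tl_math.exp. Numerically the updated kernel behaves roughly like this eager sketch (illustrative only, not code from the repo; reduction accumulation details are glossed over):

import torch


def softmax_fp16_sketch(x: torch.Tensor) -> torch.Tensor:
    # Subtract the row max, exponentiate in fp32, cast back to fp16, then normalize.
    amax = x.amax(dim=-1, keepdim=True)
    e = torch.exp((x - amax).to(torch.float32)).to(x.dtype)
    return e / e.sum(dim=-1, keepdim=True)
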
