@@ -62,6 +62,96 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o
_launcher(_reduce_kernel_kernel, (n,), x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
return out
+ --- assertExpectedJournal(TestReductions.test_fp16_math_ops_fp32_fallback)
+ from __future__ import annotations
+
+ import torch
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_compat import libdevice
+ from helion.runtime import default_launcher as _default_launcher
+
+ @triton.jit
+ def _rsqrt_fp16_kernel_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+ pid_0 = tl.program_id(0)
+ offset_0 = pid_0 * _BLOCK_SIZE_0
+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+ mask_0 = indices_0 < x_size_0
+ load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+ v_0 = load.to(tl.float32)
+ v_1 = libdevice.rsqrt(v_0)
+ v_2 = v_1.to(tl.float16)
+ tl.store(result + indices_0 * result_stride_0, v_2, mask_0)
+
+ def rsqrt_fp16_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+ result = torch.empty_like(x)
+ _BLOCK_SIZE_0 = 32
+ _launcher(_rsqrt_fp16_kernel_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+ return result
+
+ --- assertExpectedJournal(TestReductions.test_fp16_math_ops_fp32_fallback)
+ from __future__ import annotations
+
+ import torch
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_helpers import math as tl_math
+ from torch._inductor.runtime.triton_compat import libdevice
+ from helion.runtime import default_launcher as _default_launcher
+
+ @triton.jit
+ def _multi_math_ops_fp16_kernel_kernel(x, result, x_size_0, result_stride_0, result_stride_1, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+ pid_0 = tl.program_id(0)
+ offset_0 = pid_0 * _BLOCK_SIZE_0
+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+ mask_0 = indices_0 < x_size_0
+ load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+ v_0 = load.to(tl.float32)
+ v_1 = libdevice.rsqrt(v_0)
+ v_2 = v_1.to(tl.float16)
+ tl.store(result + (indices_0 * result_stride_0 + 0 * result_stride_1), v_2, mask_0)
+ load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+ v_3 = load_1.to(tl.float32)
+ v_4 = libdevice.sqrt(v_3)
+ v_5 = v_4.to(tl.float16)
+ tl.store(result + (indices_0 * result_stride_0 + 1 * result_stride_1), v_5, mask_0)
+ load_2 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+ v_6 = load_2.to(tl.float32)
+ v_7 = tl_math.sin(v_6)
+ v_8 = v_7.to(tl.float16)
+ tl.store(result + (indices_0 * result_stride_0 + 2 * result_stride_1), v_8, mask_0)
+ load_3 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+ v_9 = load_3.to(tl.float32)
+ v_10 = tl_math.cos(v_9)
+ v_11 = v_10.to(tl.float16)
+ tl.store(result + (indices_0 * result_stride_0 + 3 * result_stride_1), v_11, mask_0)
+ load_4 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+ v_12 = load_4.to(tl.float32)
+ v_13 = tl_math.log(v_12)
+ v_14 = v_13.to(tl.float16)
+ tl.store(result + (indices_0 * result_stride_0 + 4 * result_stride_1), v_14, mask_0)
+ load_5 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+ v_15 = load_5.to(tl.float32)
+ v_16 = libdevice.tanh(v_15)
+ v_17 = v_16.to(tl.float16)
+ tl.store(result + (indices_0 * result_stride_0 + 5 * result_stride_1), v_17, mask_0)
+ load_6 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+ v_18 = load_6.to(tl.float32)
+ v_19 = libdevice.log1p(v_18)
+ v_20 = v_19.to(tl.float16)
+ tl.store(result + (indices_0 * result_stride_0 + 6 * result_stride_1), v_20, mask_0)
+ load_7 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+ v_21 = load_7.to(tl.float32)
+ v_22 = tl_math.exp(v_21)
+ v_23 = v_22.to(tl.float16)
+ tl.store(result + (indices_0 * result_stride_0 + 7 * result_stride_1), v_23, mask_0)
+
+ def multi_math_ops_fp16_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+ result = torch.empty([x.size(0), 8], dtype=x.dtype, device=x.device)
+ _BLOCK_SIZE_0 = 16
+ _launcher(_multi_math_ops_fp16_kernel_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), result.stride(1), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+ return result
+
--- assertExpectedJournal(TestReductions.test_fp16_var_mean)
from __future__ import annotations