Skip to content

Commit b4ce7c0

Browse files
authored
[SYCL] Add some trivial util functions for half type. (#6691)
The usage of sycl::half and CUDA half is different, sycl::half overrides arithmetic and comparison operators(+, -, *, /, >, <,==.!=...) while CUDA math provides those arithmetic and comparison functionalities via a bunch of util functions. This PR provides some sycl::half util functions aligning with CUDA math, users who are porting half related CUDA code to SYCL will spend less effort with those util functions' help. Signed-off-by: jinge90 <[email protected]>
1 parent a776f3c commit b4ce7c0

File tree

2 files changed

+320
-0
lines changed

2 files changed

+320
-0
lines changed

sycl/include/sycl/ext/intel/math.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#pragma once
12+
#include <sycl/ext/intel/math/imf_half_trivial.hpp>
1213
#include <sycl/half_type.hpp>
1314
#include <type_traits>
1415

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
//==------------- imf_half_trivial.hpp - trivial half utils ----------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Trival half util functions.
9+
//===----------------------------------------------------------------------===//
10+
11+
#pragma once
12+
#include <sycl/half_type.hpp>
13+
14+
namespace sycl {
15+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
16+
namespace ext {
17+
namespace intel {
18+
namespace math {
19+
sycl::half hadd(sycl::half x, sycl::half y) { return x + y; }
20+
21+
sycl::half hadd_sat(sycl::half x, sycl::half y) {
22+
return sycl::clamp((x + y), sycl::half(0.f), sycl::half(1.0f));
23+
}
24+
25+
sycl::half hfma(sycl::half x, sycl::half y, sycl::half z) {
26+
return sycl::fma(x, y, z);
27+
}
28+
29+
sycl::half hfma_sat(sycl::half x, sycl::half y, sycl::half z) {
30+
return sycl::clamp(sycl::fma(x, y, z), sycl::half(0.f), sycl::half(1.0f));
31+
}
32+
33+
sycl::half hmul(sycl::half x, sycl::half y) { return x * y; }
34+
35+
sycl::half hmul_sat(sycl::half x, sycl::half y) {
36+
return sycl::clamp((x * y), sycl::half(0.f), sycl::half(1.0f));
37+
}
38+
39+
sycl::half hneg(sycl::half x) { return -x; }
40+
41+
sycl::half hsub(sycl::half x, sycl::half y) { return x - y; }
42+
43+
sycl::half hsub_sat(sycl::half x, sycl::half y) {
44+
return sycl::clamp((x - y), sycl::half(0.f), sycl::half(1.0f));
45+
}
46+
47+
sycl::half hdiv(sycl::half x, sycl::half y) { return x / y; }
48+
49+
bool heq(sycl::half x, sycl::half y) { return x == y; }
50+
51+
bool hequ(sycl::half x, sycl::half y) {
52+
if (sycl::isnan(x) || sycl::isnan(y))
53+
return true;
54+
else
55+
return x == y;
56+
}
57+
58+
bool hge(sycl::half x, sycl::half y) { return x >= y; }
59+
60+
bool hgeu(sycl::half x, sycl::half y) {
61+
if (sycl::isnan(x) || sycl::isnan(y))
62+
return true;
63+
else
64+
return x >= y;
65+
}
66+
67+
bool hgt(sycl::half x, sycl::half y) { return x > y; }
68+
69+
bool hgtu(sycl::half x, sycl::half y) {
70+
if (sycl::isnan(x) || sycl::isnan(y))
71+
return true;
72+
else
73+
return x > y;
74+
}
75+
76+
bool hle(sycl::half x, sycl::half y) { return x <= y; }
77+
78+
bool hleu(sycl::half x, sycl::half y) {
79+
if (sycl::isnan(x) || sycl::isnan(y))
80+
return true;
81+
else
82+
return x <= y;
83+
}
84+
85+
bool hlt(sycl::half x, sycl::half y) { return x < y; }
86+
87+
bool hltu(sycl::half x, sycl::half y) {
88+
if (sycl::isnan(x) || sycl::isnan(y))
89+
return true;
90+
return x < y;
91+
}
92+
93+
bool hne(sycl::half x, sycl::half y) {
94+
if (sycl::isnan(x) || sycl::isnan(y))
95+
return false;
96+
return x != y;
97+
}
98+
99+
bool hneu(sycl::half x, sycl::half y) {
100+
if (sycl::isnan(x) || sycl::isnan(y))
101+
return true;
102+
else
103+
return x != y;
104+
}
105+
106+
bool hisinf(sycl::half x) { return sycl::isinf(x); }
107+
bool hisnan(sycl::half y) { return sycl::isnan(y); }
108+
109+
sycl::half2 hadd2(sycl::half2 x, sycl::half2 y) { return x + y; }
110+
111+
sycl::half2 hadd2_sat(sycl::half2 x, sycl::half2 y) {
112+
return sycl::clamp((x + y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
113+
}
114+
115+
sycl::half2 hfma2(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
116+
return sycl::fma(x, y, z);
117+
}
118+
119+
sycl::half2 hfma2_sat(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
120+
return sycl::clamp(sycl::fma(x, y, z), sycl::half2{0.f, 0.f},
121+
sycl::half2{1.f, 1.f});
122+
}
123+
124+
sycl::half2 hmul2(sycl::half2 x, sycl::half2 y) { return x * y; }
125+
126+
sycl::half2 hmul2_sat(sycl::half2 x, sycl::half2 y) {
127+
return sycl::clamp((x * y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
128+
}
129+
130+
sycl::half2 h2div(sycl::half2 x, sycl::half2 y) { return x / y; }
131+
132+
sycl::half2 hneg2(sycl::half2 x) { return -x; }
133+
134+
sycl::half2 hsub2(sycl::half2 x, sycl::half2 y) { return x - y; }
135+
136+
sycl::half2 hsub2_sat(sycl::half2 x, sycl::half2 y) {
137+
return sycl::clamp((x - y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
138+
}
139+
140+
bool hbeq2(sycl::half2 x, sycl::half2 y) {
141+
return heq(x.s0(), y.s0()) && heq(x.s1(), y.s1());
142+
}
143+
144+
bool hbequ2(sycl::half2 x, sycl::half2 y) {
145+
return hequ(x.s0(), y.s0()) && hequ(x.s1(), y.s1());
146+
}
147+
148+
bool hbge2(sycl::half2 x, sycl::half2 y) {
149+
return hge(x.s0(), y.s0()) && hge(x.s1(), y.s1());
150+
}
151+
152+
bool hbgeu2(sycl::half2 x, sycl::half2 y) {
153+
return hgeu(x.s0(), y.s0()) && hgeu(x.s1(), y.s1());
154+
}
155+
156+
bool hbgt2(sycl::half2 x, sycl::half2 y) {
157+
return hgt(x.s0(), y.s0()) && hgt(x.s1(), y.s1());
158+
}
159+
160+
bool hbgtu2(sycl::half2 x, sycl::half2 y) {
161+
return hgtu(x.s0(), y.s0()) && hgtu(x.s1(), y.s1());
162+
}
163+
164+
bool hble2(sycl::half2 x, sycl::half2 y) {
165+
return hle(x.s0(), y.s0()) && hle(x.s1(), y.s1());
166+
}
167+
168+
bool hbleu2(sycl::half2 x, sycl::half2 y) {
169+
return hleu(x.s0(), y.s0()) && hleu(x.s1(), y.s1());
170+
}
171+
172+
bool hblt2(sycl::half2 x, sycl::half2 y) {
173+
return hlt(x.s0(), y.s0()) && hlt(x.s1(), y.s1());
174+
}
175+
176+
bool hbltu2(sycl::half2 x, sycl::half2 y) {
177+
return hltu(x.s0(), y.s0()) && hltu(x.s1(), y.s1());
178+
}
179+
180+
bool hbne2(sycl::half2 x, sycl::half2 y) {
181+
return hne(x.s0(), y.s0()) && hne(x.s1(), y.s1());
182+
}
183+
184+
bool hbneu2(sycl::half2 x, sycl::half2 y) {
185+
return hneu(x.s0(), y.s0()) && hneu(x.s1(), y.s1());
186+
}
187+
188+
sycl::half2 heq2(sycl::half2 x, sycl::half2 y) {
189+
return sycl::half2{(heq(x.s0(), y.s0()) ? 1.0f : 0.f),
190+
(heq(x.s1(), y.s1()) ? 1.0f : 0.f)};
191+
}
192+
193+
sycl::half2 hequ2(sycl::half2 x, sycl::half2 y) {
194+
return sycl::half2{(hequ(x.s0(), y.s0()) ? 1.0f : 0.f),
195+
(hequ(x.s1(), y.s1()) ? 1.0f : 0.f)};
196+
}
197+
198+
sycl::half2 hge2(sycl::half2 x, sycl::half2 y) {
199+
return sycl::half2{(hge(x.s0(), y.s0()) ? 1.0f : 0.f),
200+
(hge(x.s1(), y.s1()) ? 1.0f : 0.f)};
201+
}
202+
203+
sycl::half2 hgeu2(sycl::half2 x, sycl::half2 y) {
204+
return sycl::half2{(hgeu(x.s0(), y.s0()) ? 1.0f : 0.f),
205+
(hgeu(x.s1(), y.s1()) ? 1.0f : 0.f)};
206+
}
207+
208+
sycl::half2 hgt2(sycl::half2 x, sycl::half2 y) {
209+
return sycl::half2{(hgt(x.s0(), y.s0()) ? 1.0f : 0.f),
210+
(hgt(x.s1(), y.s1()) ? 1.0f : 0.f)};
211+
}
212+
213+
sycl::half2 hgtu2(sycl::half2 x, sycl::half2 y) {
214+
return sycl::half2{(hgtu(x.s0(), y.s0()) ? 1.0f : 0.f),
215+
(hgtu(x.s1(), y.s1()) ? 1.0f : 0.f)};
216+
}
217+
218+
sycl::half2 hle2(sycl::half2 x, sycl::half2 y) {
219+
return sycl::half2{(hle(x.s0(), y.s0()) ? 1.0f : 0.f),
220+
(hle(x.s1(), y.s1()) ? 1.0f : 0.f)};
221+
}
222+
223+
sycl::half2 hleu2(sycl::half2 x, sycl::half2 y) {
224+
return sycl::half2{(hleu(x.s0(), y.s0()) ? 1.0f : 0.f),
225+
(hleu(x.s1(), y.s1()) ? 1.0f : 0.f)};
226+
}
227+
228+
sycl::half2 hlt2(sycl::half2 x, sycl::half2 y) {
229+
return sycl::half2{(hlt(x.s0(), y.s0()) ? 1.0f : 0.f),
230+
(hlt(x.s1(), y.s1()) ? 1.0f : 0.f)};
231+
}
232+
233+
sycl::half2 hltu2(sycl::half2 x, sycl::half2 y) {
234+
return sycl::half2{(hltu(x.s0(), y.s0()) ? 1.0f : 0.f),
235+
(hltu(x.s1(), y.s1()) ? 1.0f : 0.f)};
236+
}
237+
238+
sycl::half2 hisnan2(sycl::half2 x) {
239+
return sycl::half2{(hisnan(x.s0()) ? 1.0f : 0.f),
240+
(hisnan(x.s1()) ? 1.0f : 0.f)};
241+
}
242+
243+
sycl::half2 hne2(sycl::half2 x, sycl::half2 y) {
244+
return sycl::half2{(hne(x.s0(), y.s0()) ? 1.0f : 0.f),
245+
(hne(x.s1(), y.s1()) ? 1.0f : 0.f)};
246+
}
247+
248+
sycl::half2 hneu2(sycl::half2 x, sycl::half2 y) {
249+
return sycl::half2{(hneu(x.s0(), y.s0()) ? 1.0f : 0.f),
250+
(hneu(x.s1(), y.s1()) ? 1.0f : 0.f)};
251+
}
252+
253+
sycl::half hmax(sycl::half x, sycl::half y) { return sycl::fmax(x, y); }
254+
255+
sycl::half hmax_nan(sycl::half x, sycl::half y) {
256+
if (hisnan(x) || hisnan(y))
257+
return sycl::half(NAN);
258+
else
259+
return sycl::fmax(x, y);
260+
}
261+
262+
sycl::half2 hmax2(sycl::half2 x, sycl::half2 y) {
263+
return sycl::half2{hmax(x.s0(), y.s0()), hmax(x.s1(), y.s1())};
264+
}
265+
266+
sycl::half2 hmax2_nan(sycl::half2 x, sycl::half2 y) {
267+
return sycl::half2{hmax_nan(x.s0(), y.s0()), hmax_nan(x.s1(), y.s1())};
268+
}
269+
270+
sycl::half hmin(sycl::half x, sycl::half y) { return sycl::fmin(x, y); }
271+
272+
sycl::half hmin_nan(sycl::half x, sycl::half y) {
273+
if (hisnan(x) || hisnan(y))
274+
return sycl::half(NAN);
275+
else
276+
return sycl::fmin(x, y);
277+
}
278+
279+
sycl::half2 hmin2(sycl::half2 x, sycl::half2 y) {
280+
return sycl::half2{hmin(x.s0(), y.s0()), hmin(x.s1(), y.s1())};
281+
}
282+
283+
sycl::half2 hmin2_nan(sycl::half2 x, sycl::half2 y) {
284+
return sycl::half2{hmin_nan(x.s0(), y.s0()), hmin_nan(x.s1(), y.s1())};
285+
}
286+
287+
sycl::half2 hcmadd(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
288+
return sycl::half2{x.s0() * y.s0() - x.s1() * y.s1() + z.s0(),
289+
x.s0() * y.s1() + x.s1() * y.s0() + z.s1()};
290+
}
291+
292+
sycl::half hfma_relu(sycl::half x, sycl::half y, sycl::half z) {
293+
sycl::half r = sycl::fma(x, y, z);
294+
if (!hisnan(r)) {
295+
if (r < 0.f)
296+
return sycl::half{0.f};
297+
else
298+
return r;
299+
}
300+
return r;
301+
}
302+
303+
sycl::half2 hfma2_relu(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
304+
sycl::half2 r = sycl::fma(x, y, z);
305+
if (!hisnan(r.s0()) && r.s0() < 0.f)
306+
r.s0() = 0.f;
307+
if (!hisnan(r.s1()) && r.s1() < 0.f)
308+
r.s1() = 0.f;
309+
return r;
310+
}
311+
312+
sycl::half habs(sycl::half x) { return sycl::fabs(x); }
313+
314+
sycl::half2 habs2(sycl::half2 x) { return sycl::fabs(x); }
315+
} // namespace math
316+
} // namespace intel
317+
} // namespace ext
318+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
319+
} // namespace sycl

0 commit comments

Comments
 (0)