Skip to content

Commit 3dc5e0e

Browse files
min-jean-chotoyxu
andauthored
Add aten::special_airy_ai.out (#1039)
- `special_airy_ai.out` - `special_airy_ai` --------- Co-authored-by: Yutao Xu <[email protected]>
1 parent f224138 commit 3dc5e0e

File tree

6 files changed

+240
-1
lines changed

6 files changed

+240
-1
lines changed

src/ATen/native/xpu/AiryAi.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#include <ATen/native/UnaryOps.h>
2+
#include <ATen/native/DispatchStub.h>
3+
#include <ATen/native/TensorIterator.h>
4+
#include <ATen/native/xpu/sycl/AiryAiKernel.h>
5+
6+
namespace at {
7+
namespace native {
8+
REGISTER_XPU_DISPATCH(special_airy_ai_stub, &xpu::airy_ai_kernel);
9+
10+
} // namespace native
11+
} // namespace at

src/ATen/native/xpu/XPUFallback.template

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
200200
"segment_reduce",
201201
"_segment_reduce_backward",
202202
"sinc.out",
203-
"special_airy_ai.out",
204203
"_thnn_fused_gru_cell",
205204
"_to_sparse",
206205
"_to_sparse_csr",
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/native/TensorIterator.h>
3+
#include <ATen/native/xpu/sycl/Loops.h>
4+
#include <ATen/native/xpu/sycl/MathExtensions.h>
5+
6+
#include <ATen/native/xpu/sycl/AiryAiKernel.h>
7+
8+
namespace at::native::xpu {
9+
template <typename scalar_t>
10+
struct AiryAiFunctor {
11+
scalar_t operator()(scalar_t a) const {
12+
return airy_ai_forward(a);
13+
}
14+
};
15+
16+
void airy_ai_kernel(TensorIteratorBase& iter) {
17+
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "airy_ai_xpu", [&]() {
18+
gpu_kernel(iter, AiryAiFunctor<scalar_t>());
19+
});
20+
}
21+
22+
} // namespace at::native::xpu
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pragma once
2+
3+
#include <ATen/native/TensorIterator.h>
4+
5+
namespace at::native::xpu {
6+
7+
TORCH_XPU_API void airy_ai_kernel(TensorIteratorBase& iter);
8+
9+
} // namespace at::native::xpu

src/ATen/native/xpu/sycl/MathExtensions.h

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,4 +1318,187 @@ static inline C10_HOST_DEVICE scalar_t bessel_y1_forward(scalar_t x) {
13181318
scalar_t(0.797884560802865355879892119868763737) / std::sqrt(x);
13191319
} // bessel_y1_forward(scalar_t x)
13201320

1321+
template <typename T>
1322+
static inline C10_HOST_DEVICE T airy_ai_forward(T x) {
1323+
static const T AN[] = {
1324+
+3.46538101525629032477e-01f,
1325+
+1.20075952739645805542e+01f,
1326+
+7.62796053615234516538e+01f,
1327+
+1.68089224934630576269e+02f,
1328+
+1.59756391350164413639e+02f,
1329+
+7.05360906840444183113e+01f,
1330+
+1.40264691163389668864e+01f,
1331+
+9.99999999999999995305e-01f,
1332+
};
1333+
1334+
static const T AD[] = {
1335+
+5.67594532638770212846e-01f,
1336+
+1.47562562584847203173e+01f,
1337+
+8.45138970141474626562e+01f,
1338+
+1.77318088145400459522e+02f,
1339+
+1.64234692871529701831e+02f,
1340+
+7.14778400825575695274e+01f,
1341+
+1.40959135607834029598e+01f,
1342+
+1.00000000000000000470e+00f,
1343+
};
1344+
1345+
static const T AFN[] = {
1346+
-1.31696323418331795333e-01f,
1347+
-6.26456544431912369773e-01f,
1348+
-6.93158036036933542233e-01f,
1349+
-2.79779981545119124951e-01f,
1350+
-4.91900132609500318020e-02f,
1351+
-4.06265923594885404393e-03f,
1352+
-1.59276496239262096340e-04f,
1353+
-2.77649108155232920844e-06f,
1354+
-1.67787698489114633780e-08f,
1355+
};
1356+
1357+
static const T AFD[] = {
1358+
+1.33560420706553243746e+01f,
1359+
+3.26825032795224613948e+01f,
1360+
+2.67367040941499554804e+01f,
1361+
+9.18707402907259625840e+00f,
1362+
+1.47529146771666414581e+00f,
1363+
+1.15687173795188044134e-01f,
1364+
+4.40291641615211203805e-03f,
1365+
+7.54720348287414296618e-05f,
1366+
+4.51850092970580378464e-07f,
1367+
};
1368+
1369+
static const T AGN[] = {
1370+
+1.97339932091685679179e-02f,
1371+
+3.91103029615688277255e-01f,
1372+
+1.06579897599595591108e+00f,
1373+
+9.39169229816650230044e-01f,
1374+
+3.51465656105547619242e-01f,
1375+
+6.33888919628925490927e-02f,
1376+
+5.85804113048388458567e-03f,
1377+
+2.82851600836737019778e-04f,
1378+
+6.98793669997260967291e-06f,
1379+
+8.11789239554389293311e-08f,
1380+
+3.41551784765923618484e-10f,
1381+
};
1382+
1383+
static const T AGD[] = {
1384+
+9.30892908077441974853e+00f,
1385+
+1.98352928718312140417e+01f,
1386+
+1.55646628932864612953e+01f,
1387+
+5.47686069422975497931e+00f,
1388+
+9.54293611618961883998e-01f,
1389+
+8.64580826352392193095e-02f,
1390+
+4.12656523824222607191e-03f,
1391+
+1.01259085116509135510e-04f,
1392+
+1.17166733214413521882e-06f,
1393+
+4.91834570062930015649e-09f,
1394+
};
1395+
1396+
int domain_flag = 0;
1397+
1398+
T ai;
1399+
1400+
if (std::isinf(x)) {
1401+
return std::numeric_limits<T>::quiet_NaN();
1402+
}
1403+
1404+
if (x > T(103.892f)) {
1405+
return T(0.0f);
1406+
}
1407+
1408+
T f;
1409+
T g;
1410+
T k;
1411+
1412+
if (x < T(-2.09f)) {
1413+
T z = T(1.0f) / (T(-2.0f) * x * std::sqrt(-x) / T(3.0f));
1414+
1415+
T afn = 0.0f;
1416+
1417+
for (uint8_t index = 0; index <= 8; index++) {
1418+
afn = afn * (z * z) + AFN[index];
1419+
}
1420+
1421+
T afd = 0.0f;
1422+
1423+
for (uint8_t index = 0; index <= 8; index++) {
1424+
afd = afd * (z * z) + AFD[index];
1425+
}
1426+
1427+
T agn = 0.0f;
1428+
1429+
for (uint8_t index = 0; index <= 10 + 0; index++) {
1430+
agn = agn * (z * z) + AGN[index];
1431+
}
1432+
1433+
T agd = 0.0f;
1434+
1435+
for (uint8_t index = 0; index <= 10 - 1; index++) {
1436+
agd = agd * (z * z) + AGD[index];
1437+
}
1438+
1439+
T t = T(-2.0f) * x * std::sqrt(-x) / T(3.0f) +
1440+
T(0.25f) * T(3.14159265358979323846f);
1441+
1442+
return T(5.64189583547756286948e-01f) / std::sqrt(std::sqrt(-x)) *
1443+
(std::sin(t) * (T(1.0f) + z * z * afn / afd) -
1444+
std::cos(t) * (z * agn / agd));
1445+
}
1446+
1447+
if (x >= T(2.09f)) {
1448+
domain_flag = 5;
1449+
1450+
T zeta = T(2.0f) * x * std::sqrt(x) / T(3.0f);
1451+
1452+
T an = 0.0f;
1453+
1454+
for (uint8_t index = 0; index <= 7; index++) {
1455+
an = an * (T(1.0f) / zeta) + AN[index];
1456+
}
1457+
1458+
T ad = 0.0f;
1459+
1460+
for (uint8_t index = 0; index <= 7; index++) {
1461+
ad = ad * (T(1.0f) / zeta) + AD[index];
1462+
}
1463+
1464+
ai = T(5.64189583547756286948e-01f) * (an / ad) /
1465+
(T(2.0f) * std::sqrt(std::sqrt(x)) * std::exp(zeta));
1466+
1467+
if (x > T(8.3203353f)) {
1468+
return ai;
1469+
}
1470+
}
1471+
1472+
f = 1.0f;
1473+
g = x;
1474+
k = 1.0f;
1475+
1476+
T m = 1.0f;
1477+
T n = x;
1478+
T t = 1.0f;
1479+
T z = x * x * x;
1480+
1481+
while (t > std::numeric_limits<T>::epsilon()) {
1482+
m *= z;
1483+
k += T(1.0f);
1484+
m /= k;
1485+
n *= z;
1486+
k += T(1.0f);
1487+
n /= k;
1488+
m /= k;
1489+
f += m;
1490+
k += T(1.0f);
1491+
n /= k;
1492+
g += n;
1493+
1494+
t = std::abs(m / f);
1495+
}
1496+
1497+
if ((domain_flag & 1) == 0) {
1498+
return T(0.355028053887817239260f) * f - T(0.258819403792806798405f) * g;
1499+
}
1500+
1501+
return ai;
1502+
} // T airy_ai(T x)
1503+
13211504
} // namespace at::native::xpu

yaml/native/native_functions.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6730,6 +6730,21 @@
67306730
XPU: angle_out
67316731
tags: pointwise
67326732

6733+
- func: special_airy_ai(Tensor x) -> Tensor
6734+
python_module: special
6735+
structured_delegate: special_airy_ai.out
6736+
variants: function
6737+
tags: pointwise
6738+
6739+
- func: special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
6740+
dispatch:
6741+
XPU: special_airy_ai_out
6742+
python_module: special
6743+
structured_inherits: TensorIteratorBase
6744+
structured: True
6745+
variants: function
6746+
tags: pointwise
6747+
67336748
- func: special_bessel_j0(Tensor self) -> Tensor
67346749
python_module: special
67356750
structured_delegate: special_bessel_j0.out

0 commit comments

Comments
 (0)