@@ -1318,4 +1318,187 @@ static inline C10_HOST_DEVICE scalar_t bessel_y1_forward(scalar_t x) {
13181318 scalar_t (0.797884560802865355879892119868763737 ) / std::sqrt (x);
13191319} // bessel_y1_forward(scalar_t x)
13201320
1321+ template <typename T>
1322+ static inline C10_HOST_DEVICE T airy_ai_forward (T x) {
1323+ static const T AN[] = {
1324+ +3 .46538101525629032477e-01f ,
1325+ +1 .20075952739645805542e+01f ,
1326+ +7 .62796053615234516538e+01f ,
1327+ +1 .68089224934630576269e+02f ,
1328+ +1 .59756391350164413639e+02f ,
1329+ +7 .05360906840444183113e+01f ,
1330+ +1 .40264691163389668864e+01f ,
1331+ +9 .99999999999999995305e-01f ,
1332+ };
1333+
1334+ static const T AD[] = {
1335+ +5 .67594532638770212846e-01f ,
1336+ +1 .47562562584847203173e+01f ,
1337+ +8 .45138970141474626562e+01f ,
1338+ +1 .77318088145400459522e+02f ,
1339+ +1 .64234692871529701831e+02f ,
1340+ +7 .14778400825575695274e+01f ,
1341+ +1 .40959135607834029598e+01f ,
1342+ +1 .00000000000000000470e+00f ,
1343+ };
1344+
1345+ static const T AFN[] = {
1346+ -1 .31696323418331795333e-01f ,
1347+ -6 .26456544431912369773e-01f ,
1348+ -6 .93158036036933542233e-01f ,
1349+ -2 .79779981545119124951e-01f ,
1350+ -4 .91900132609500318020e-02f ,
1351+ -4 .06265923594885404393e-03f ,
1352+ -1 .59276496239262096340e-04f ,
1353+ -2 .77649108155232920844e-06f ,
1354+ -1 .67787698489114633780e-08f ,
1355+ };
1356+
1357+ static const T AFD[] = {
1358+ +1 .33560420706553243746e+01f ,
1359+ +3 .26825032795224613948e+01f ,
1360+ +2 .67367040941499554804e+01f ,
1361+ +9 .18707402907259625840e+00f ,
1362+ +1 .47529146771666414581e+00f ,
1363+ +1 .15687173795188044134e-01f ,
1364+ +4 .40291641615211203805e-03f ,
1365+ +7 .54720348287414296618e-05f ,
1366+ +4 .51850092970580378464e-07f ,
1367+ };
1368+
1369+ static const T AGN[] = {
1370+ +1 .97339932091685679179e-02f ,
1371+ +3 .91103029615688277255e-01f ,
1372+ +1 .06579897599595591108e+00f ,
1373+ +9 .39169229816650230044e-01f ,
1374+ +3 .51465656105547619242e-01f ,
1375+ +6 .33888919628925490927e-02f ,
1376+ +5 .85804113048388458567e-03f ,
1377+ +2 .82851600836737019778e-04f ,
1378+ +6 .98793669997260967291e-06f ,
1379+ +8 .11789239554389293311e-08f ,
1380+ +3 .41551784765923618484e-10f ,
1381+ };
1382+
1383+ static const T AGD[] = {
1384+ +9 .30892908077441974853e+00f ,
1385+ +1 .98352928718312140417e+01f ,
1386+ +1 .55646628932864612953e+01f ,
1387+ +5 .47686069422975497931e+00f ,
1388+ +9 .54293611618961883998e-01f ,
1389+ +8 .64580826352392193095e-02f ,
1390+ +4 .12656523824222607191e-03f ,
1391+ +1 .01259085116509135510e-04f ,
1392+ +1 .17166733214413521882e-06f ,
1393+ +4 .91834570062930015649e-09f ,
1394+ };
1395+
1396+ int domain_flag = 0 ;
1397+
1398+ T ai;
1399+
1400+ if (std::isinf (x)) {
1401+ return std::numeric_limits<T>::quiet_NaN ();
1402+ }
1403+
1404+ if (x > T (103 .892f )) {
1405+ return T (0 .0f );
1406+ }
1407+
1408+ T f;
1409+ T g;
1410+ T k;
1411+
1412+ if (x < T (-2 .09f )) {
1413+ T z = T (1 .0f ) / (T (-2 .0f ) * x * std::sqrt (-x) / T (3 .0f ));
1414+
1415+ T afn = 0 .0f ;
1416+
1417+ for (uint8_t index = 0 ; index <= 8 ; index++) {
1418+ afn = afn * (z * z) + AFN[index];
1419+ }
1420+
1421+ T afd = 0 .0f ;
1422+
1423+ for (uint8_t index = 0 ; index <= 8 ; index++) {
1424+ afd = afd * (z * z) + AFD[index];
1425+ }
1426+
1427+ T agn = 0 .0f ;
1428+
1429+ for (uint8_t index = 0 ; index <= 10 + 0 ; index++) {
1430+ agn = agn * (z * z) + AGN[index];
1431+ }
1432+
1433+ T agd = 0 .0f ;
1434+
1435+ for (uint8_t index = 0 ; index <= 10 - 1 ; index++) {
1436+ agd = agd * (z * z) + AGD[index];
1437+ }
1438+
1439+ T t = T (-2 .0f ) * x * std::sqrt (-x) / T (3 .0f ) +
1440+ T (0 .25f ) * T (3 .14159265358979323846f );
1441+
1442+ return T (5 .64189583547756286948e-01f ) / std::sqrt (std::sqrt (-x)) *
1443+ (std::sin (t) * (T (1 .0f ) + z * z * afn / afd) -
1444+ std::cos (t) * (z * agn / agd));
1445+ }
1446+
1447+ if (x >= T (2 .09f )) {
1448+ domain_flag = 5 ;
1449+
1450+ T zeta = T (2 .0f ) * x * std::sqrt (x) / T (3 .0f );
1451+
1452+ T an = 0 .0f ;
1453+
1454+ for (uint8_t index = 0 ; index <= 7 ; index++) {
1455+ an = an * (T (1 .0f ) / zeta) + AN[index];
1456+ }
1457+
1458+ T ad = 0 .0f ;
1459+
1460+ for (uint8_t index = 0 ; index <= 7 ; index++) {
1461+ ad = ad * (T (1 .0f ) / zeta) + AD[index];
1462+ }
1463+
1464+ ai = T (5 .64189583547756286948e-01f ) * (an / ad) /
1465+ (T (2 .0f ) * std::sqrt (std::sqrt (x)) * std::exp (zeta));
1466+
1467+ if (x > T (8 .3203353f )) {
1468+ return ai;
1469+ }
1470+ }
1471+
1472+ f = 1 .0f ;
1473+ g = x;
1474+ k = 1 .0f ;
1475+
1476+ T m = 1 .0f ;
1477+ T n = x;
1478+ T t = 1 .0f ;
1479+ T z = x * x * x;
1480+
1481+ while (t > std::numeric_limits<T>::epsilon ()) {
1482+ m *= z;
1483+ k += T (1 .0f );
1484+ m /= k;
1485+ n *= z;
1486+ k += T (1 .0f );
1487+ n /= k;
1488+ m /= k;
1489+ f += m;
1490+ k += T (1 .0f );
1491+ n /= k;
1492+ g += n;
1493+
1494+ t = std::abs (m / f);
1495+ }
1496+
1497+ if ((domain_flag & 1 ) == 0 ) {
1498+ return T (0 .355028053887817239260f ) * f - T (0 .258819403792806798405f ) * g;
1499+ }
1500+
1501+ return ai;
1502+ } // T airy_ai(T x)
1503+
13211504} // namespace at::native::xpu
0 commit comments