@@ -2115,6 +2115,7 @@ def expi_jvp(primals, tangents):
21152115 (x_dot ,) = tangents
21162116 return expi (x ), jnp .exp (x ) / x * x_dot
21172117
2118+
21182119@custom_derivatives .custom_jvp
21192120@jit
21202121def sici (x : ArrayLike ) -> tuple [Array , Array ]:
@@ -2151,6 +2152,35 @@ def sici(x: ArrayLike) -> tuple[Array, Array]:
21512152 f"Argument `x` to sici must be real-valued. Got dtype { x .dtype } ."
21522153 )
21532154
2155+ x_abs = jnp .abs (x )
2156+
2157+ si_series , ci_series = _sici_series (x_abs )
2158+ si_asymp , ci_asymp = _sici_asympt (x_abs )
2159+ si_approx , ci_approx = _sici_approx (x_abs )
2160+
2161+ cond1 = x_abs <= 4
2162+ cond2 = (x_abs > 4 ) & (x_abs <= 1e9 )
2163+
2164+ si = jnp .select ([cond1 , cond2 ], [si_series , si_asymp ], si_approx )
2165+ ci = jnp .select ([cond1 , cond2 ], [ci_series , ci_asymp ], ci_approx )
2166+
2167+ si = jnp .sign (x ) * si
2168+ ci = jnp .where (isneginf (x ), np .nan , ci )
2169+
2170+ return si , ci
2171+
2172+ def _sici_approx (x : Array ):
2173+ # sici approximation valid for x >= 1E9
2174+ si = (np .pi / 2 ) - jnp .cos (x ) / x
2175+ ci = jnp .sin (x ) / x
2176+
2177+ si = jnp .where (isposinf (x ), np .pi / 2 , si )
2178+ ci = jnp .where (isposinf (x ), 0.0 , ci )
2179+
2180+ return si , ci
2181+
2182+ def _sici_series (x : Array ):
2183+ # sici series valid for x >= 0 and x <= 4
21542184 def si_series (x ):
21552185 # Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
21562186 SN = np .array ([- 8.39167827910303881427E-11 ,
@@ -2185,19 +2215,125 @@ def ci_series(x):
21852215 t = x * x
21862216 return np .euler_gamma + jnp .log (x ) + t * jnp .polyval (CN , t ) / jnp .polyval (CD , t )
21872217
2188- si = jnp .piecewise (
2189- jnp . abs ( x ) ,
2190- [ x == 0 , jnp . isinf ( x )] ,
2191- [ 0.0 , np . pi / 2 , si_series ]
2218+ si = jnp .where (
2219+ x == 0 ,
2220+ 0.0 ,
2221+ si_series ( x )
21922222 )
21932223
2194- ci = jnp .piecewise (
2195- jnp . abs ( x ) ,
2196- [ x == 0 , isposinf ( x ), isneginf ( x )] ,
2197- [ - np . inf , 0.0 , np . nan , ci_series ]
2224+ ci = jnp .where (
2225+ x == 0 ,
2226+ - np . inf ,
2227+ ci_series ( x )
21982228 )
21992229
2200- si = jnp .sign (x ) * si
2230+ return si , ci
2231+
2232+ def _sici_asympt (x : Array ):
2233+ # sici asympt valid for x > 4 & x <= 1E9
2234+ s = jnp .sin (x )
2235+ c = jnp .cos (x )
2236+ z = 1.0 / (x * x )
2237+
2238+ # Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
2239+ FN4 = np .array ([
2240+ 4.23612862892216586994E0 ,
2241+ 5.45937717161812843388E0 ,
2242+ 1.62083287701538329132E0 ,
2243+ 1.67006611831323023771E-1 ,
2244+ 6.81020132472518137426E-3 ,
2245+ 1.08936580650328664411E-4 ,
2246+ 5.48900223421373614008E-7 ,
2247+ ], dtype = x .dtype )
2248+ FD4 = np .array ([
2249+ 1 ,
2250+ 8.16496634205391016773E0 ,
2251+ 7.30828822505564552187E0 ,
2252+ 1.86792257950184183883E0 ,
2253+ 1.78792052963149907262E-1 ,
2254+ 7.01710668322789753610E-3 ,
2255+ 1.10034357153915731354E-4 ,
2256+ 5.48900252756255700982E-7 ,
2257+ ], dtype = x .dtype )
2258+ GN4 = np .array ([
2259+ 8.71001698973114191777E-2 ,
2260+ 6.11379109952219284151E-1 ,
2261+ 3.97180296392337498885E-1 ,
2262+ 7.48527737628469092119E-2 ,
2263+ 5.38868681462177273157E-3 ,
2264+ 1.61999794598934024525E-4 ,
2265+ 1.97963874140963632189E-6 ,
2266+ 7.82579040744090311069E-9 ,
2267+ ], dtype = x .dtype )
2268+ GD4 = np .array ([
2269+ 1 ,
2270+ 1.64402202413355338886E0 ,
2271+ 6.66296701268987968381E-1 ,
2272+ 9.88771761277688796203E-2 ,
2273+ 6.22396345441768420760E-3 ,
2274+ 1.73221081474177119497E-4 ,
2275+ 2.02659182086343991969E-6 ,
2276+ 7.82579218933534490868E-9 ,
2277+ ], dtype = x .dtype )
2278+
2279+ FN8 = np .array ([
2280+ 4.55880873470465315206E-1 ,
2281+ 7.13715274100146711374E-1 ,
2282+ 1.60300158222319456320E-1 ,
2283+ 1.16064229408124407915E-2 ,
2284+ 3.49556442447859055605E-4 ,
2285+ 4.86215430826454749482E-6 ,
2286+ 3.20092790091004902806E-8 ,
2287+ 9.41779576128512936592E-11 ,
2288+ 9.70507110881952024631E-14 ,
2289+ ], dtype = x .dtype )
2290+ FD8 = np .array ([
2291+ 1.0 ,
2292+ 9.17463611873684053703E-1 ,
2293+ 1.78685545332074536321E-1 ,
2294+ 1.22253594771971293032E-2 ,
2295+ 3.58696481881851580297E-4 ,
2296+ 4.92435064317881464393E-6 ,
2297+ 3.21956939101046018377E-8 ,
2298+ 9.43720590350276732376E-11 ,
2299+ 9.70507110881952025725E-14 ,
2300+ ], dtype = x .dtype )
2301+ GN8 = np .array ([
2302+ 6.97359953443276214934E-1 ,
2303+ 3.30410979305632063225E-1 ,
2304+ 3.84878767649974295920E-2 ,
2305+ 1.71718239052347903558E-3 ,
2306+ 3.48941165502279436777E-5 ,
2307+ 3.47131167084116673800E-7 ,
2308+ 1.70404452782044526189E-9 ,
2309+ 3.85945925430276600453E-12 ,
2310+ 3.14040098946363334640E-15 ,
2311+ ], dtype = x .dtype )
2312+ GD8 = np .array ([
2313+ 1.0 ,
2314+ 1.68548898811011640017E0 ,
2315+ 4.87852258695304967486E-1 ,
2316+ 4.67913194259625806320E-2 ,
2317+ 1.90284426674399523638E-3 ,
2318+ 3.68475504442561108162E-5 ,
2319+ 3.57043223443740838771E-7 ,
2320+ 1.72693748966316146736E-9 ,
2321+ 3.87830166023954706752E-12 ,
2322+ 3.14040098946363335242E-15 ,
2323+ ], dtype = x .dtype )
2324+
2325+ f4 = jnp .polyval (FN4 , z ) / (x * jnp .polyval (FD4 , z ))
2326+ g4 = z * jnp .polyval (GN4 , z ) / jnp .polyval (GD4 , z )
2327+
2328+ f8 = jnp .polyval (FN8 , z ) / (x * jnp .polyval (FD8 , z ))
2329+ g8 = z * jnp .polyval (GN8 , z ) / jnp .polyval (GD8 , z )
2330+
2331+ mask = x < 8.0
2332+ f = jnp .where (mask , f4 , f8 )
2333+ g = jnp .where (mask , g4 , g8 )
2334+
2335+ si = (np .pi / 2 ) - f * c - g * s
2336+ ci = f * s - g * c
22012337
22022338 return si , ci
22032339
@@ -2213,6 +2349,7 @@ def sici_jvp(primals, tangents):
22132349 tangent_out = (sin_term * t , cos_term * t )
22142350 return primal_out , tangent_out
22152351
2352+
22162353def _expn1 (x : Array , n : Array ) -> Array :
22172354 # exponential integral En
22182355 _c = _lax_const
0 commit comments