Skip to content

Commit ecd0b33

Browse files
Merge pull request #33178 from michaeleliot:sici-2
PiperOrigin-RevId: 834767990
2 parents a737725 + ecb1656 commit ecd0b33

File tree

2 files changed

+158
-9
lines changed

2 files changed

+158
-9
lines changed

jax/_src/scipy/special.py

Lines changed: 146 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,6 +2115,7 @@ def expi_jvp(primals, tangents):
21152115
(x_dot,) = tangents
21162116
return expi(x), jnp.exp(x) / x * x_dot
21172117

2118+
21182119
@custom_derivatives.custom_jvp
21192120
@jit
21202121
def sici(x: ArrayLike) -> tuple[Array, Array]:
@@ -2151,6 +2152,35 @@ def sici(x: ArrayLike) -> tuple[Array, Array]:
21512152
f"Argument `x` to sici must be real-valued. Got dtype {x.dtype}."
21522153
)
21532154

2155+
x_abs = jnp.abs(x)
2156+
2157+
si_series, ci_series = _sici_series(x_abs)
2158+
si_asymp, ci_asymp = _sici_asympt(x_abs)
2159+
si_approx, ci_approx = _sici_approx(x_abs)
2160+
2161+
cond1 = x_abs <= 4
2162+
cond2 = (x_abs > 4) & (x_abs <= 1e9)
2163+
2164+
si = jnp.select([cond1, cond2], [si_series, si_asymp], si_approx)
2165+
ci = jnp.select([cond1, cond2], [ci_series, ci_asymp], ci_approx)
2166+
2167+
si = jnp.sign(x) * si
2168+
ci = jnp.where(isneginf(x), np.nan, ci)
2169+
2170+
return si, ci
2171+
2172+
def _sici_approx(x: Array):
2173+
# sici approximation valid for x >= 1E9
2174+
si = (np.pi / 2) - jnp.cos(x) / x
2175+
ci = jnp.sin(x) / x
2176+
2177+
si = jnp.where(isposinf(x), np.pi / 2, si)
2178+
ci = jnp.where(isposinf(x), 0.0, ci)
2179+
2180+
return si, ci
2181+
2182+
def _sici_series(x: Array):
2183+
# sici series valid for x >= 0 and x <= 4
21542184
def si_series(x):
21552185
# Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
21562186
SN = np.array([-8.39167827910303881427E-11,
@@ -2185,19 +2215,125 @@ def ci_series(x):
21852215
t = x * x
21862216
return np.euler_gamma + jnp.log(x) + t * jnp.polyval(CN, t) / jnp.polyval(CD, t)
21872217

2188-
si = jnp.piecewise(
2189-
jnp.abs(x),
2190-
[x == 0, jnp.isinf(x)],
2191-
[0.0, np.pi/2, si_series]
2218+
si = jnp.where(
2219+
x == 0,
2220+
0.0,
2221+
si_series(x)
21922222
)
21932223

2194-
ci = jnp.piecewise(
2195-
jnp.abs(x),
2196-
[x == 0, isposinf(x), isneginf(x)],
2197-
[-np.inf, 0.0, np.nan, ci_series]
2224+
ci = jnp.where(
2225+
x == 0,
2226+
-np.inf,
2227+
ci_series(x)
21982228
)
21992229

2200-
si = jnp.sign(x) * si
2230+
return si, ci
2231+
2232+
def _sici_asympt(x: Array):
2233+
# sici asympt valid for x > 4 & x <= 1E9
2234+
s = jnp.sin(x)
2235+
c = jnp.cos(x)
2236+
z = 1.0 / (x * x)
2237+
2238+
# Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
2239+
FN4 = np.array([
2240+
4.23612862892216586994E0,
2241+
5.45937717161812843388E0,
2242+
1.62083287701538329132E0,
2243+
1.67006611831323023771E-1,
2244+
6.81020132472518137426E-3,
2245+
1.08936580650328664411E-4,
2246+
5.48900223421373614008E-7,
2247+
], dtype=x.dtype)
2248+
FD4 = np.array([
2249+
1,
2250+
8.16496634205391016773E0,
2251+
7.30828822505564552187E0,
2252+
1.86792257950184183883E0,
2253+
1.78792052963149907262E-1,
2254+
7.01710668322789753610E-3,
2255+
1.10034357153915731354E-4,
2256+
5.48900252756255700982E-7,
2257+
], dtype=x.dtype)
2258+
GN4 = np.array([
2259+
8.71001698973114191777E-2,
2260+
6.11379109952219284151E-1,
2261+
3.97180296392337498885E-1,
2262+
7.48527737628469092119E-2,
2263+
5.38868681462177273157E-3,
2264+
1.61999794598934024525E-4,
2265+
1.97963874140963632189E-6,
2266+
7.82579040744090311069E-9,
2267+
], dtype=x.dtype)
2268+
GD4 = np.array([
2269+
1,
2270+
1.64402202413355338886E0,
2271+
6.66296701268987968381E-1,
2272+
9.88771761277688796203E-2,
2273+
6.22396345441768420760E-3,
2274+
1.73221081474177119497E-4,
2275+
2.02659182086343991969E-6,
2276+
7.82579218933534490868E-9,
2277+
], dtype=x.dtype)
2278+
2279+
FN8 = np.array([
2280+
4.55880873470465315206E-1,
2281+
7.13715274100146711374E-1,
2282+
1.60300158222319456320E-1,
2283+
1.16064229408124407915E-2,
2284+
3.49556442447859055605E-4,
2285+
4.86215430826454749482E-6,
2286+
3.20092790091004902806E-8,
2287+
9.41779576128512936592E-11,
2288+
9.70507110881952024631E-14,
2289+
], dtype=x.dtype)
2290+
FD8 = np.array([
2291+
1.0,
2292+
9.17463611873684053703E-1,
2293+
1.78685545332074536321E-1,
2294+
1.22253594771971293032E-2,
2295+
3.58696481881851580297E-4,
2296+
4.92435064317881464393E-6,
2297+
3.21956939101046018377E-8,
2298+
9.43720590350276732376E-11,
2299+
9.70507110881952025725E-14,
2300+
], dtype=x.dtype)
2301+
GN8 = np.array([
2302+
6.97359953443276214934E-1,
2303+
3.30410979305632063225E-1,
2304+
3.84878767649974295920E-2,
2305+
1.71718239052347903558E-3,
2306+
3.48941165502279436777E-5,
2307+
3.47131167084116673800E-7,
2308+
1.70404452782044526189E-9,
2309+
3.85945925430276600453E-12,
2310+
3.14040098946363334640E-15,
2311+
], dtype=x.dtype)
2312+
GD8 = np.array([
2313+
1.0,
2314+
1.68548898811011640017E0,
2315+
4.87852258695304967486E-1,
2316+
4.67913194259625806320E-2,
2317+
1.90284426674399523638E-3,
2318+
3.68475504442561108162E-5,
2319+
3.57043223443740838771E-7,
2320+
1.72693748966316146736E-9,
2321+
3.87830166023954706752E-12,
2322+
3.14040098946363335242E-15,
2323+
], dtype=x.dtype)
2324+
2325+
f4 = jnp.polyval(FN4, z) / (x * jnp.polyval(FD4, z))
2326+
g4 = z * jnp.polyval(GN4, z) / jnp.polyval(GD4, z)
2327+
2328+
f8 = jnp.polyval(FN8, z) / (x * jnp.polyval(FD8, z))
2329+
g8 = z * jnp.polyval(GN8, z) / jnp.polyval(GD8, z)
2330+
2331+
mask = x < 8.0
2332+
f = jnp.where(mask, f4, f8)
2333+
g = jnp.where(mask, g4, g8)
2334+
2335+
si = (np.pi / 2) - f * c - g * s
2336+
ci = f * s - g * c
22012337

22022338
return si, ci
22032339

@@ -2213,6 +2349,7 @@ def sici_jvp(primals, tangents):
22132349
tangent_out = (sin_term * t, cos_term * t)
22142350
return primal_out, tangent_out
22152351

2352+
22162353
def _expn1(x: Array, n: Array) -> Array:
22172354
# exponential integral En
22182355
_c = _lax_const

tests/lax_scipy_special_functions_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,25 @@ def testSiciEdgeCases(self):
393393
lax_op = lambda x: lsp_special.sici(x)
394394
si_scipy, ci_scipy = scipy_op(x_samples)
395395
si_jax, ci_jax = lax_op(x_samples)
396+
396397
expected_si = np.array([0.0, np.pi/2, -np.pi/2], dtype=dtype)
397398
expected_ci = np.array([-np.inf, 0.0, np.nan], dtype=dtype)
398399
self.assertAllClose(si_jax, si_scipy, atol=1e-6, rtol=1e-6)
399400
self.assertAllClose(ci_jax, ci_scipy, atol=1e-6, rtol=1e-6)
400401
self.assertAllClose(si_jax, expected_si, atol=1e-6, rtol=1e-6)
401402
self.assertAllClose(ci_jax, expected_ci, atol=1e-6, rtol=1e-6)
402403

404+
@jtu.sample_product(
405+
scale=[1, 10, 1e9],
406+
shape=[(5,), (10,)]
407+
)
408+
def testSiciValueRanges(self, scale, shape):
409+
rng = jtu.rand_default(self.rng(), scale=scale)
410+
args_maker = lambda: [rng(shape, jnp.float32)]
411+
rtol = 5e-3 if jtu.test_device_matches(["tpu"]) else 1e-6
412+
self._CheckAgainstNumpy(
413+
osp_special.sici, lsp_special.sici, args_maker, rtol=rtol)
414+
403415
def testSiciRaiseOnComplexInput(self):
404416
samples = jnp.arange(5, dtype=complex)
405417
with self.assertRaisesRegex(ValueError, "Argument `x` to sici must be real-valued."):

0 commit comments

Comments
 (0)