@@ -288,35 +288,71 @@ def testExpiDisableJit(self):
288288 self .assertAllClose (result_jit , result_nojit )
289289
290290 def testGammaIncBoundaryValues (self ):
291- dtype = jax .numpy . zeros ( 0 ). dtype # default float dtype.
291+ dtype = jax .dtypes . canonicalize_dtype ( float )
292292 nan = float ('nan' )
293293 inf = float ('inf' )
294294 if jtu .parse_version (scipy .__version__ ) >= (1 , 16 ):
295- samples_slice = slice (None )
295+ a_samples = [0 , 0 , 0 , 1 , nan , 1 , nan , 0 , 1 , 1 , nan ]
296+ x_samples = [0 , 1 , 2 , 0 , 1 , nan , nan , inf , inf , - 1 , inf ]
296297 else :
297298 # disable samples that contradict with scipy/scipy#22441
298- samples_slice = slice (None , - 1 )
299- args_maker = lambda : [np .array ([0 , 0 , 0 , 1 , nan , 1 , nan , 0 , 1 , nan ][samples_slice ]).astype (dtype ),
300- np .array ([0 , 1 , 2 , 0 , 1 , nan , nan , inf , inf , inf ][samples_slice ]).astype (dtype )]
299+ a_samples = [0 , 0 , 0 , 1 , nan , 1 , nan , 0 , 1 , 1 ]
300+ x_samples = [0 , 1 , 2 , 0 , 1 , nan , nan , inf , inf , - 1 ]
301+
302+ args_maker = lambda : (np .array (a_samples , dtype = dtype ), np .array (x_samples , dtype = dtype ))
303+
301304 rtol = 1E-3 if jtu .test_device_matches (["tpu" ]) else 1e-5
302305 self ._CheckAgainstNumpy (lsp_special .gammainc , osp_special .gammainc , args_maker , rtol = rtol )
303306 self ._CompileAndCheck (lsp_special .gammainc , args_maker , rtol = rtol )
304307
305308 def testGammaIncCBoundaryValues (self ):
306- dtype = jax .numpy . zeros ( 0 ). dtype # default float dtype.
309+ dtype = jax .dtypes . canonicalize_dtype ( float )
307310 nan = float ('nan' )
308311 inf = float ('inf' )
309312 if jtu .parse_version (scipy .__version__ ) >= (1 , 16 ):
310- samples_slice = slice (None )
313+ a_samples = [0 , 0 , 0 , 1 , nan , 1 , nan , 0 , 1 , 1 , nan ]
314+ x_samples = [0 , 1 , 2 , 0 , 1 , nan , nan , inf , inf , - 1 , inf ]
311315 else :
312316 # disable samples that contradict with scipy/scipy#22441
313- samples_slice = slice (None , - 1 )
314- args_maker = lambda : [np .array ([0 , 0 , 0 , 1 , nan , 1 , nan , 0 , 1 , 1 , nan ][samples_slice ]).astype (dtype ),
315- np .array ([0 , 1 , 2 , 0 , 1 , nan , nan , inf , inf , - 1 , inf ][samples_slice ]).astype (dtype )]
317+ a_samples = [0 , 0 , 0 , 1 , nan , 1 , nan , 0 , 1 , 1 ]
318+ x_samples = [0 , 1 , 2 , 0 , 1 , nan , nan , inf , inf , - 1 ]
319+
320+ args_maker = lambda : (np .array (a_samples , dtype = dtype ), np .array (x_samples , dtype = dtype ))
321+
316322 rtol = 1E-3 if jtu .test_device_matches (["tpu" ]) else 1e-5
317323 self ._CheckAgainstNumpy (lsp_special .gammaincc , osp_special .gammaincc , args_maker , rtol = rtol )
318324 self ._CompileAndCheck (lsp_special .gammaincc , args_maker , rtol = rtol )
319325
326+ def testBetaIncBoundaryValues (self ):
327+ dtype = jax .dtypes .canonicalize_dtype (float )
328+ fi = jax .numpy .finfo (dtype )
329+ nan = float ('nan' )
330+ inf = float ('inf' )
331+ tiny = fi .tiny
332+ eps = fi .eps
333+ if jtu .parse_version (scipy .__version__ ) >= (1 , 16 ):
334+ # TODO(pearu): enable tiny samples when a fix to scipy/scipy#22682
335+ # will be available
336+ a_samples = [nan , - 0.5 , inf , 0 , eps , 1 , tiny ][:- 1 ]
337+ b_samples = [nan , - 0.5 , inf , 0 , eps , 1 , tiny ][:- 1 ]
338+ elif jtu .parse_version (scipy .__version__ ) >= (1 , 12 ):
339+ # disabled samples that contradict with scipy/scipy#22425
340+ a_samples = [nan , - 0.5 , 0.5 ]
341+ b_samples = [nan , - 0.5 , 0.5 ]
342+ else :
343+ a_samples = [- 0.5 , 0.5 ]
344+ b_samples = [- 0.5 , 0.5 ]
345+ x_samples = [nan , - 0.5 , 0 , 0.5 , 1 , 1.5 ]
346+
347+ a_samples = np .array (a_samples , dtype = dtype )
348+ b_samples = np .array (b_samples , dtype = dtype )
349+ x_samples = np .array (x_samples , dtype = dtype )
350+
351+ args_maker = lambda : np .meshgrid (a_samples , b_samples , x_samples )
352+
353+ rtol = 1E-3 if jtu .test_device_matches (["tpu" ]) else 5e-5
354+ self ._CheckAgainstNumpy (osp_special .betainc , lsp_special .betainc , args_maker , rtol = rtol )
355+ self ._CompileAndCheck (lsp_special .betainc , args_maker , rtol = rtol )
320356
321357if __name__ == "__main__" :
322358 absltest .main (testLoader = jtu .JaxTestLoader ())
0 commit comments