@@ -303,15 +303,16 @@ def body_fn(vals):
303303
304304def igamma_impl (a , x , * , dtype ):
305305 is_nan = bitwise_or (_isnan (a ), _isnan (x ))
306- x_is_zero = eq (x , _const (x , 0 ))
307306 x_is_infinity = eq (x , _const (x , float ('inf' )))
308- domain_error = bitwise_or (lt (x , _const (x , 0 )), le (a , _const (a , 0 )))
309- use_igammac = bitwise_and (gt (x , _const (x , 1 )), gt (x , a ))
307+ a_is_zero = eq (a , _const (a , 0 ))
308+ x_is_zero = eq (x , _const (x , 0 ))
309+ domain_error = _reduce (bitwise_or , [lt (x , _const (x , 0 )), lt (a , _const (a , 0 )), bitwise_and (a_is_zero , x_is_zero )])
310+
311+ use_igammac = bitwise_and (ge (x , _const (x , 1 )), gt (x , a ))
310312 ax = a * log (x ) - x - lgamma (a )
311313 underflow = lt (ax , - log (dtypes .finfo (dtype ).max ))
312314 ax = exp (ax )
313- enabled = bitwise_not (
314- _reduce (bitwise_or ,[x_is_zero , domain_error , underflow , is_nan ]))
315+ enabled = bitwise_not (_reduce (bitwise_or , [x_is_zero , domain_error , underflow , is_nan , x_is_infinity ]))
315316
316317 output = select (
317318 use_igammac ,
@@ -323,8 +324,7 @@ def igamma_impl(a, x, *, dtype):
323324 )
324325 output = select (x_is_zero , full_like (a , 0 ), output )
325326 output = select (x_is_infinity , full_like (a , 1 ), output )
326- output = select (bitwise_or (domain_error , is_nan ),
327- full_like (a , float ('nan' )), output )
327+ output = select (domain_error , full_like (a , float ('nan' )), output )
328328 return output
329329
330330def _igammac_continued_fraction (ax , x , a , enabled , dtype , mode ):
@@ -433,22 +433,26 @@ def body_fn(vals):
433433 raise ValueError (f"Invalid mode: { mode } " )
434434
435435def igammac_impl (a , x , * , dtype ):
436- out_of_range = bitwise_or (le (x , _const (x , 0 )), le (a , _const (a , 0 )))
436+ is_nan = bitwise_or (_isnan (a ), _isnan (x ))
437+ a_is_zero = eq (a , _const (a , 0 ))
438+ x_is_zero = eq (x , _const (x , 0 ))
439+ x_is_infinity = eq (x , _const (x , float ('inf' )))
440+ domain_error = _reduce (bitwise_or , [lt (x , _const (x , 0 )), lt (a , _const (a , 0 )), bitwise_and (a_is_zero , x_is_zero )])
437441 use_igamma = bitwise_or (lt (x , _const (x , 1 )), lt (x , a ))
438442 ax = a * log (x ) - x - lgamma (a )
439443 underflow = lt (ax , - log (dtypes .finfo (dtype ).max ))
440- enabled = bitwise_not (bitwise_or ( out_of_range , underflow ))
444+ enabled = bitwise_not (_reduce ( bitwise_or , [ domain_error , underflow , is_nan , x_is_infinity , a_is_zero ] ))
441445 ax = exp (ax )
442446
443447 igamma_call = _igamma_series (ax , x , a , bitwise_and (enabled , use_igamma ),
444448 dtype , IgammaMode .VALUE )
445449 igammac_cf_call = _igammac_continued_fraction (ax , x , a ,
446450 bitwise_and (enabled , bitwise_not (use_igamma )), dtype , IgammaMode .VALUE )
447451
448- result = select (use_igamma , _const (a , 1 ) - igamma_call , igammac_cf_call )
449- x_is_infinity = eq ( x , _const ( x , float ( 'inf' )) )
450- result = select (x_is_infinity , full_like (result , 0 ), result )
451- return select ( out_of_range , full_like ( a , 1 ), result )
452+ output = select (use_igamma , _const (a , 1 ) - igamma_call , igammac_cf_call )
453+ output = select ( bitwise_or ( x_is_infinity , a_is_zero ), full_like ( output , 0 ), output )
454+ output = select (domain_error , full_like (a , float ( 'nan' )), output )
455+ return output
452456
453457def igamma_grad_a_impl (a , x , * , dtype ):
454458 is_nan = bitwise_or (_isnan (a ), _isnan (x ))
0 commit comments