Skip to content

Commit 7b7fc60

Browse files
committed
Add expit function
1 parent 8b4a275 commit 7b7fc60

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,3 +913,29 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
913913
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
914914
)
915915
return xp.sin(y) / y
916+
917+
918+
def expit(x: Array, /, *, xp: ModuleType | None = None) -> Array:
919+
"""
920+
Return the expit function.
921+
922+
The expit function, also known as the logistic sigmoid function.
923+
It is the inverse of the logit function.
924+
925+
Parameters
926+
----------
927+
x : array
928+
Input array.
929+
xp : array_namespace, optional
930+
The standard-compatible namespace for `x`. Default: infer.
931+
932+
Returns
933+
-------
934+
array
935+
An array of the same shape as x. Its entries are expit of the
936+
corresponding entry of x.
937+
"""
938+
if xp is None:
939+
xp = array_namespace(x)
940+
941+
return 1.0 / (1.0 + xp.exp(-x))

tests/test_funcs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010
from hypothesis import given
1111
from hypothesis import strategies as st
12+
from scipy import special
1213

1314
from array_api_extra import (
1415
apply_where,
@@ -26,6 +27,7 @@
2627
sinc,
2728
)
2829
from array_api_extra._lib import Backend
30+
from array_api_extra._lib._funcs import expit
2931
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
3032
from array_api_extra._lib._utils._compat import device as get_device
3133
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
@@ -1003,3 +1005,12 @@ def test_device(self, xp: ModuleType, device: Device):
10031005

10041006
def test_xp(self, xp: ModuleType):
10051007
xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))
1008+
1009+
1010+
class TestExpit:
1011+
def test_simple(self, xp: ModuleType):
1012+
x = xp.asarray([2, 3, 4, 5])
1013+
np_x = np.asarray([2, 3, 4, 5])
1014+
actual = expit(x)
1015+
expected = special.expit(np_x)
1016+
xp_assert_close(actual, expected)

0 commit comments

Comments
 (0)