Skip to content

Commit 59bdb9c

Browse files
committed
nit
1 parent 970dd33 commit 59bdb9c

File tree

3 files changed

+10
-16
lines changed

3 files changed

+10
-16
lines changed

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ reportUnreachable = false
238238
reportUnusedParameter = false
239239
# cyclic imports inside function bodies
240240
reportImportCycles = false
241+
# PyRight can't trace types in lambdas
242+
reportUnknownLambdaType = false
241243

242244
executionEnvironments = [
243245
{ root = "tests", reportPrivateUsage = false },

src/array_api_extra/_lib/_funcs.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -567,23 +567,15 @@ def isclose(
567567
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
568568
if a_inexact or b_inexact:
569569
# prevent warnings on numpy and dask on inf - inf
570-
meta_xp = meta_namespace(a, b, xp=xp)
571-
572-
def where_inf(a: Array, b: Array) -> Array:
573-
return (
574-
meta_xp.isinf(a)
575-
& meta_xp.isinf(b)
576-
& (meta_xp.sign(a) == meta_xp.sign(b))
577-
)
578-
579-
def where_not_inf(a: Array, b: Array) -> Array:
580-
# Note: inf <= inf is True!
581-
return meta_xp.abs(a - b) <= (atol + rtol * meta_xp.abs(b))
582-
570+
mxp = meta_namespace(a, b, xp=xp)
583571
out = apply_where(
584-
xp.isinf(a) | xp.isinf(b), where_inf, where_not_inf, a, b, xp=xp
572+
xp.isinf(a) | xp.isinf(b),
573+
lambda a, b: mxp.isinf(a) & mxp.isinf(b) & (mxp.sign(a) == mxp.sign(b)), # pyright: ignore[reportUnknownArgumentType]
574+
# Note: inf <= inf is True!
575+
lambda a, b: mxp.abs(a - b) <= (atol + rtol * mxp.abs(b)), # pyright: ignore[reportUnknownArgumentType]
576+
*(a, b),
577+
xp=xp,
585578
)
586-
587579
if equal_nan:
588580
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
589581
return out

0 commit comments

Comments
 (0)