Skip to content

Commit 0e8d06f

Browse files
Fix warnings from mapped_aval being made private
1 parent a29b06a commit 0e8d06f

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

equinox/_enum.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,7 @@ def is_traced(self) -> bool:
204204

205205
if TYPE_CHECKING:
206206
import enum
207-
from typing import ClassVar
208-
from typing_extensions import Self
207+
from typing import ClassVar, Self
209208

210209
class _Sequence(type):
211210
def __getitem__(cls, item) -> str: ...

equinox/internal/_primitive.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -345,22 +345,25 @@ def _vprim_impl(*inputs, prim, __axis_size, __axis_name, __batch_axes, params):
345345
return impl(*inputs)
346346

347347

348+
if hasattr(jax.extend.core, "mapped_aval"):
349+
_mapped_aval = jax.extend.core.mapped_aval # pyright: ignore[reportAttributeAccessIssue]
350+
else:
351+
_mapped_aval = jax.core.mapped_aval
352+
if hasattr(jax.extend.core, "unmapped_aval"):
353+
_unmapped_aval = jax.extend.core.unmapped_aval # pyright: ignore[reportAttributeAccessIssue,reportAssignmentType]
354+
else:
355+
_unmapped_aval = jax.core.unmapped_aval # pyright: ignore[reportAssignmentType]
348356
if jax.__version_info__ >= (0, 5, 1):
357+
_old_unmapped_aval = _unmapped_aval
349358

350359
def _unmapped_aval(axis_size, axis_name, axis, aval):
351360
del axis_name
352-
return jax.core.unmapped_aval(axis_size, axis, aval) # pyright: ignore[reportCallIssue]
353-
354-
else:
355-
# signature (axis_size, axis_name, axis, aval)
356-
_unmapped_aval = jax.core.unmapped_aval # pyright: ignore[reportAssignmentType]
361+
return _old_unmapped_aval(axis_size, axis, aval) # pyright: ignore[reportCallIssue]
357362

358363

359364
def _vprim_abstract_eval(*inputs, prim, __axis_size, __axis_name, __batch_axes, params):
360365
assert len(inputs) == len(__batch_axes)
361-
inputs = [
362-
jax.core.mapped_aval(__axis_size, b, x) for x, b in zip(inputs, __batch_axes)
363-
]
366+
inputs = [_mapped_aval(__axis_size, b, x) for x, b in zip(inputs, __batch_axes)]
364367
abstract_eval = _vprim_abstract_eval_registry[prim]
365368
outs = abstract_eval(*inputs, **dict(params))
366369
outs = [_unmapped_aval(__axis_size, __axis_name, 0, x) for x in outs]

0 commit comments

Comments
 (0)