Skip to content

Commit aaf7206

Browse files
committed
Allow scalars for actx.tag, document it is best-effort
1 parent 03d28fd commit aaf7206

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

arraycontext/context.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@
125125
)
126126
from warnings import warn
127127

128-
from typing_extensions import Self
128+
from typing_extensions import Self, override
129129

130130
from pytools import memoize_method
131131

@@ -146,7 +146,6 @@
146146
ArrayOrArithContainerOrScalarT,
147147
ArrayOrContainerOrScalar,
148148
ArrayOrContainerOrScalarT,
149-
ArrayOrContainerT,
150149
ContainerOrScalarT,
151150
NumpyOrContainerOrScalar,
152151
ScalarLike,
@@ -217,6 +216,7 @@ def __init__(self) -> None:
217216
def _get_fake_numpy_namespace(self) -> BaseFakeNumpyNamespace:
218217
...
219218

219+
@override
220220
def __hash__(self) -> int:
221221
raise TypeError(f"unhashable type: '{type(self).__name__}'")
222222

@@ -333,12 +333,14 @@ def freeze_thaw(
333333
@abstractmethod
334334
def tag(self,
335335
tags: ToTagSetConvertible,
336-
array: ArrayOrContainerT) -> ArrayOrContainerT:
336+
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
337337
"""If the array type used by the array context is capable of capturing
338338
metadata, return a version of *array* with the *tags* applied. *array*
339339
itself is not modified. When working with array containers, the
340340
tags are applied to each leaf of the container.
341341
342+
Tagging is best-effort. Untaggable types will be returned as-is.
343+
342344
See :ref:`metadata` as well as application-specific metadata types.
343345
344346
.. versionadded:: 2021.2

arraycontext/impl/pytato/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@
7676
Array,
7777
ArrayOrArithContainerOrScalarT,
7878
ArrayOrContainerOrScalarT,
79-
ArrayOrContainerT,
8079
ArrayOrScalar,
8180
ScalarLike,
8281
is_scalar_like,
@@ -1031,8 +1030,8 @@ def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
10311030
@override
10321031
def tag(self,
10331032
tags: ToTagSetConvertible,
1034-
array: ArrayOrContainerT,
1035-
) -> ArrayOrContainerT:
1033+
array: ArrayOrContainerOrScalarT,
1034+
) -> ArrayOrContainerOrScalarT:
10361035
def _tag(ary: Array) -> Array:
10371036
import jax.numpy as jnp
10381037
if isinstance(ary, jnp.ndarray):

0 commit comments

Comments
 (0)