Skip to content

Commit 057ccac

Browse files
committed
feat(typing): improve types in arraycontext.fake_numpy
1 parent 4496660 commit 057ccac

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

arraycontext/fake_numpy.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import operator
3131
from abc import ABC, abstractmethod
3232
from dataclasses import dataclass
33-
from typing import TYPE_CHECKING, Any, Literal, cast, overload
33+
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
3434

3535
import numpy as np
3636
from typing_extensions import deprecated
@@ -78,7 +78,7 @@ def __init__(self, array_context: ArrayContext):
7878
def _get_fake_numpy_linalg_namespace(self):
7979
return BaseFakeNumpyLinalgNamespace(self._array_context)
8080

81-
_numpy_math_functions = frozenset({
81+
_numpy_math_functions: ClassVar[frozenset[str]] = frozenset({
8282
# https://numpy.org/doc/stable/reference/routines.math.html
8383

8484
# FIXME: Heads up: not all of these are supported yet.
@@ -560,7 +560,9 @@ def logical_not(self, x: ArrayOrContainerOrScalar, /
560560

561561
# {{{ BaseFakeNumpyLinalgNamespace
562562

563-
def _reduce_norm(actx: ArrayContext, arys: Iterable[ArrayOrScalar], ord: float | None):
563+
def _reduce_norm(actx: ArrayContext,
564+
arys: Iterable[ArrayOrScalar],
565+
ord: float | None) -> ArrayOrScalar:
564566
from functools import reduce
565567
from numbers import Number
566568

@@ -617,7 +619,7 @@ def norm(self,
617619
raise NotImplementedError("only vector norms are implemented")
618620

619621
if ary.size == 0:
620-
return ary.dtype.type(0)
622+
return cast("ScalarLike", ary.dtype.type(0))
621623

622624
from numbers import Number
623625
if ord == 2:

0 commit comments

Comments
 (0)