Skip to content

Commit bc67417

Browse files
alexfiklinducer
authored andcommitted
feat(typing): add vdot to BaseFakeNumpyNamespace
1 parent c44f000 commit bc67417

File tree

5 files changed

+19
-7
lines changed

5 files changed

+19
-7
lines changed

arraycontext/fake_numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,11 @@ def inner(a: ArrayOrScalar) -> ArrayOrScalar:
312312
# subclasses. Defining them as abstract methods would define them
313313
# as attributes, making __getattr__ fail to retrieve the intended function.
314314

315+
def vdot(self,
316+
a: ArrayOrContainerOrScalarT,
317+
b: ArrayOrContainerOrScalarT, /
318+
) -> ArrayOrScalar: ...
319+
315320
def broadcast_to(self,
316321
array: ArrayOrContainerOrScalar, /,
317322
shape: tuple[int, ...]

arraycontext/impl/jax/fake_numpy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def stack(self, arrays, axis=0):
140140

141141
# {{{ linear algebra
142142

143-
def vdot(self, x, y, dtype=None):
143+
@override
144+
def vdot(self, a, b, dtype=None):
144145
from arraycontext import rec_multimap_reduce_array_container
145146

146147
def _rec_vdot(ary1, ary2):
@@ -151,7 +152,7 @@ def _rec_vdot(ary1, ary2):
151152

152153
return jnp.vdot(ary1, ary2)
153154

154-
return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
155+
return rec_multimap_reduce_array_container(sum, _rec_vdot, a, b)
155156

156157
# }}}
157158

arraycontext/impl/numpy/fake_numpy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,9 @@ def inner_ravel(ary: ArrayOrScalar) -> ArrayOrScalar:
233233

234234
return rec_map_container(inner_ravel, a)
235235

236-
def vdot(self, x, y):
237-
return rec_multimap_reduce_array_container(sum, np.vdot, x, y)
236+
@override
237+
def vdot(self, a, b):
238+
return rec_multimap_reduce_array_container(sum, np.vdot, a, b)
238239

239240
def any(self, a, /):
240241
return rec_map_reduce_array_container(partial(reduce, np.logical_or),

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from arraycontext.typing import (
6363
Array,
6464
ArrayOrContainerOrScalar,
65+
ArrayOrContainerOrScalarT,
6566
ArrayOrScalar,
6667
)
6768

@@ -187,11 +188,14 @@ def stack(self, arrays, axis=0):
187188

188189
# {{{ linear algebra
189190

190-
def vdot(self, x, y, dtype=None):
191+
@override
192+
def vdot(self,
193+
a: ArrayOrContainerOrScalarT, b: ArrayOrContainerOrScalarT,
194+
dtype: DTypeLike | None = None):
191195
return rec_multimap_reduce_array_container(
192196
sum,
193197
partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue),
194-
x, y)
198+
a, b)
195199

196200
# }}}
197201

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ def min(self,
301301
def absolute(self, a):
302302
return self.abs(a)
303303

304-
def vdot(self, a: Array, b: Array):
304+
@override
305+
def vdot(self, a: ArrayOrContainerOrScalar, b: ArrayOrContainerOrScalar):
305306
return rec_multimap_array_container(pt.vdot, a, b)
306307

307308
# }}}

0 commit comments

Comments
 (0)