Skip to content

Commit ff16029

Browse files
kaushikcfdinducer
authored andcommitted
Implements NumpyArrayContext
1 parent 9827060 commit ff16029

File tree

4 files changed

+384
-2
lines changed

4 files changed

+384
-2
lines changed

arraycontext/__init__.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
tag_axes,
7979
)
8080
from .impl.jax import EagerJAXArrayContext
81+
from .impl.numpy import NumpyArrayContext
8182
from .impl.pyopencl import PyOpenCLArrayContext
8283
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
8384
from .loopy import make_loopy_program
@@ -91,63 +92,174 @@
9192

9293

9394
__all__ = (
95+
"Array",
96+
"Array",
97+
"Array",
98+
"Array",
9499
"Array",
95100
"Array",
96101
"ArrayContainer",
102+
"ArrayContainer",
103+
"ArrayContainer",
104+
"ArrayContainerT",
97105
"ArrayContainerT",
106+
"ArrayContainerT",
107+
"ArrayContext",
98108
"ArrayContext",
109+
"ArrayContext",
110+
"ArrayOrContainer",
99111
"ArrayOrContainer",
112+
"ArrayOrContainer",
113+
"ArrayOrContainerOrScalar",
114+
"ArrayOrContainerOrScalar",
100115
"ArrayOrContainerOrScalar",
101116
"ArrayOrContainerOrScalarT",
117+
"ArrayOrContainerOrScalarT",
118+
"ArrayOrContainerOrScalarT",
119+
"ArrayOrContainerT",
120+
"ArrayOrContainerT",
102121
"ArrayOrContainerT",
103122
"ArrayT",
123+
"ArrayT",
124+
"ArrayT",
125+
"CommonSubexpressionTag",
104126
"CommonSubexpressionTag",
127+
"CommonSubexpressionTag",
128+
"EagerJAXArrayContext",
105129
"EagerJAXArrayContext",
130+
"EagerJAXArrayContext",
131+
"ElementwiseMapKernelTag",
106132
"ElementwiseMapKernelTag",
133+
"ElementwiseMapKernelTag",
134+
"NotAnArrayContainerError",
135+
"NotAnArrayContainerError",
107136
"NotAnArrayContainerError",
137+
"NumpyArrayContext",
108138
"PyOpenCLArrayContext",
139+
"PyOpenCLArrayContext",
140+
"PyOpenCLArrayContext",
141+
"PytatoJAXArrayContext",
142+
"PytatoJAXArrayContext",
109143
"PytatoJAXArrayContext",
110144
"PytatoPyOpenCLArrayContext",
145+
"PytatoPyOpenCLArrayContext",
146+
"PytatoPyOpenCLArrayContext",
147+
"PytestArrayContextFactory",
111148
"PytestArrayContextFactory",
149+
"PytestArrayContextFactory",
150+
"PytestPyOpenCLArrayContextFactory",
112151
"PytestPyOpenCLArrayContextFactory",
152+
"PytestPyOpenCLArrayContextFactory",
153+
"Scalar",
154+
"Scalar",
155+
"Scalar",
156+
"Scalar",
113157
"Scalar",
114158
"Scalar",
115159
"ScalarLike",
160+
"ScalarLike",
161+
"ScalarLike",
162+
"dataclass_array_container",
163+
"dataclass_array_container",
116164
"dataclass_array_container",
117165
"deserialize_container",
166+
"deserialize_container",
167+
"deserialize_container",
168+
"flat_size_and_dtype",
118169
"flat_size_and_dtype",
170+
"flat_size_and_dtype",
171+
"flatten",
172+
"flatten",
119173
"flatten",
120174
"freeze",
175+
"freeze",
176+
"freeze",
177+
"from_numpy",
121178
"from_numpy",
179+
"from_numpy",
180+
"get_container_context_opt",
122181
"get_container_context_opt",
182+
"get_container_context_opt",
183+
"get_container_context_recursively",
123184
"get_container_context_recursively",
185+
"get_container_context_recursively",
186+
"get_container_context_recursively_opt",
187+
"get_container_context_recursively_opt",
124188
"get_container_context_recursively_opt",
125189
"is_array_container",
190+
"is_array_container",
191+
"is_array_container",
192+
"is_array_container_type",
193+
"is_array_container_type",
126194
"is_array_container_type",
127195
"make_loopy_program",
196+
"make_loopy_program",
197+
"make_loopy_program",
198+
"map_array_container",
128199
"map_array_container",
200+
"map_array_container",
201+
"map_reduce_array_container",
129202
"map_reduce_array_container",
203+
"map_reduce_array_container",
204+
"mapped_over_array_containers",
130205
"mapped_over_array_containers",
206+
"mapped_over_array_containers",
207+
"multimap_array_container",
208+
"multimap_array_container",
131209
"multimap_array_container",
132210
"multimap_reduce_array_container",
211+
"multimap_reduce_array_container",
212+
"multimap_reduce_array_container",
213+
"multimapped_over_array_containers",
133214
"multimapped_over_array_containers",
215+
"multimapped_over_array_containers",
216+
"outer",
217+
"outer",
134218
"outer",
135219
"pytest_generate_tests_for_array_contexts",
220+
"pytest_generate_tests_for_array_contexts",
221+
"pytest_generate_tests_for_array_contexts",
222+
"pytest_generate_tests_for_pyopencl_array_context",
223+
"pytest_generate_tests_for_pyopencl_array_context",
136224
"pytest_generate_tests_for_pyopencl_array_context",
137225
"rec_map_array_container",
226+
"rec_map_array_container",
227+
"rec_map_array_container",
228+
"rec_map_reduce_array_container",
138229
"rec_map_reduce_array_container",
230+
"rec_map_reduce_array_container",
231+
"rec_multimap_array_container",
139232
"rec_multimap_array_container",
233+
"rec_multimap_array_container",
234+
"rec_multimap_reduce_array_container",
140235
"rec_multimap_reduce_array_container",
236+
"rec_multimap_reduce_array_container",
237+
"register_multivector_as_array_container",
238+
"register_multivector_as_array_container",
141239
"register_multivector_as_array_container",
142240
"serialize_container",
241+
"serialize_container",
242+
"serialize_container",
143243
"stringify_array_container_tree",
144244
"tag_axes",
245+
"tag_axes",
246+
"tag_axes",
145247
"thaw",
248+
"thaw",
249+
"thaw",
250+
"to_numpy",
146251
"to_numpy",
252+
"to_numpy",
253+
"unflatten",
254+
"unflatten",
147255
"unflatten",
148256
"with_array_context",
257+
"with_array_context",
258+
"with_array_context",
259+
"with_container_arithmetic",
260+
"with_container_arithmetic",
149261
"with_container_arithmetic"
150-
)
262+
)
151263

152264

153265
# {{{ deprecation handling

arraycontext/container/arithmetic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
"""
3434

3535
from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union
36-
from warnings import warn
3736

3837
import numpy as np
3938

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""
2+
.. currentmodule:: arraycontext
3+
4+
5+
A mod :`numpy`-based array context.
6+
7+
.. autoclass:: NumpyArrayContext
8+
"""
9+
__copyright__ = """
10+
Copyright (C) 2021 University of Illinois Board of Trustees
11+
"""
12+
13+
__license__ = """
14+
Permission is hereby granted, free of charge, to any person obtaining a copy
15+
of this software and associated documentation files (the "Software"), to deal
16+
in the Software without restriction, including without limitation the rights
17+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18+
copies of the Software, and to permit persons to whom the Software is
19+
furnished to do so, subject to the following conditions:
20+
21+
The above copyright notice and this permission notice shall be included in
22+
all copies or substantial portions of the Software.
23+
24+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
30+
THE SOFTWARE.
31+
"""
32+
33+
from typing import Dict, Sequence, Union
34+
35+
import numpy as np
36+
37+
import loopy as lp
38+
from pytools.tag import Tag
39+
40+
from arraycontext.context import ArrayContext
41+
42+
43+
class NumpyArrayContext(ArrayContext):
44+
"""
45+
A :class:`ArrayContext` that uses :mod:`numpy.ndarray` to represent arrays
46+
47+
48+
.. automethod:: __init__
49+
"""
50+
def __init__(self):
51+
super().__init__()
52+
self._loopy_transform_cache: \
53+
Dict[lp.TranslationUnit, lp.TranslationUnit] = {}
54+
55+
self.array_types = (np.ndarray,)
56+
57+
def _get_fake_numpy_namespace(self):
58+
from .fake_numpy import NumpyFakeNumpyNamespace
59+
return NumpyFakeNumpyNamespace(self)
60+
61+
# {{{ ArrayContext interface
62+
63+
def clone(self):
64+
return type(self)()
65+
66+
def empty(self, shape, dtype):
67+
return np.empty(shape, dtype=dtype)
68+
69+
def zeros(self, shape, dtype):
70+
return np.zeros(shape, dtype)
71+
72+
def from_numpy(self, np_array: np.ndarray):
73+
# Uh oh...
74+
return np_array
75+
76+
def to_numpy(self, array):
77+
# Uh oh...
78+
return array
79+
80+
def call_loopy(self, t_unit, **kwargs):
81+
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
82+
try:
83+
t_unit = self._loopy_transform_cache[t_unit]
84+
except KeyError:
85+
orig_t_unit = t_unit
86+
t_unit = self.transform_loopy_program(t_unit)
87+
self._loopy_transform_cache[orig_t_unit] = t_unit
88+
del orig_t_unit
89+
90+
_, result = t_unit(**kwargs)
91+
92+
return result
93+
94+
def freeze(self, array):
95+
return array
96+
97+
def thaw(self, array):
98+
return array
99+
100+
# }}}
101+
102+
def transform_loopy_program(self, t_unit):
103+
raise ValueError("NumpyArrayContext does not implement "
104+
"transform_loopy_program. Sub-classes are supposed "
105+
"to implement it.")
106+
107+
def tag(self, tags: Union[Sequence[Tag], Tag], array):
108+
# Numpy doesn't support tagging
109+
return array
110+
111+
def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array):
112+
return array
113+
114+
def einsum(self, spec, *args, arg_names=None, tagged=()):
115+
return np.einsum(spec, *args)
116+
117+
@property
118+
def permits_inplace_modification(self):
119+
return True
120+
121+
@property
122+
def supports_nonscalar_broadcasting(self):
123+
return True
124+
125+
@property
126+
def permits_advanced_indexing(self):
127+
return True

0 commit comments

Comments
 (0)