Skip to content

Commit 806a0d3

Browse files
authored
np.arange (#191)
1 parent 378958a commit 806a0d3

File tree

5 files changed

+72
-10
lines changed

5 files changed

+72
-10
lines changed

numba_dpcomp/numba_dpcomp/mlir/linalg_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def linalg_generic(self, inputs, outputs, iterators, maps, body):
143143
def linalg_index(self, dim):
144144
return self._linalg_index(self._context, dim)
145145

146-
def from_elements(self, values, dtype):
146+
def from_elements(self, values, dtype=None):
147147
return self._from_elements(self._context, values, dtype)
148148

149149
def extract(self, value, indices):

numba_dpcomp/numba_dpcomp/mlir/numpy/funcs.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,36 @@ def ones_like_impl(builder, arr):
396396
return _init_impl(builder, arr.shape, arr.dtype, 1)
397397

398398

399+
@register_func("numpy.arange", numpy.arange)
400+
def arange_impl(builder, start, stop=None, step=None, dtype=None):
401+
if stop is None:
402+
stop = start
403+
start = 0
404+
405+
if step is None:
406+
step = 1
407+
408+
if dtype is None:
409+
dtype = builder.int64
410+
411+
inc = builder.select(step < 0, 1, -1)
412+
count = (stop - start + step + inc) // step
413+
count = builder.select(count < 0, 0, count)
414+
415+
start = builder.from_elements(start)
416+
step = builder.from_elements(step)
417+
init = builder.init_tensor([count], dtype)
418+
419+
iterators = ["parallel"]
420+
maps = ["(d0) -> (0)", "(d0) -> (0)", "(d0) -> (d0)"]
421+
422+
def body(a, b, c):
423+
i = _linalg_index(0)
424+
return a + b * i
425+
426+
return builder.linalg_generic((start, step), init, iterators, maps, body)
427+
428+
399429
@register_func("numpy.eye", numpy.eye)
400430
def eye_impl(builder, N, M=None, k=0, dtype=None):
401431
if M is None:

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ def _gen_tests():
8282
"test_fuse_argmin_argmax_max_min", # numpy argmin, argmax
8383
"test_max", # max reduction
8484
"test_min", # min reduction
85-
"test_arange", # numpy.arange
85+
"test_arange", # select issue, complex
8686
"test_pi", # np.random.ranf
8787
"test_simple20", # AssertionError not raised
88-
"test_simple24", # numpy.arange
88+
"test_simple24", # getitem with array
8989
"test_0d_array", # numpy prod
9090
"test_argmin", # numpy.argmin
9191
"test_argmax", # numpy.argmax
@@ -104,19 +104,18 @@ def _gen_tests():
104104
"test_simple19", # np.dot unsupported args
105105
"test_no_hoisting_with_member_function_call", # set support
106106
"test_parfor_dtype_type", # dtype cast
107-
"test_tuple3", # numpy.arange
108107
"test_parfor_array_access3", # TypeError: unsupported operand type(s) for -: 'NoneType' and 'NoneType'
109108
"test_preparfor_canonicalize_kws", # array.argsort
110109
"test_parfor_array_access4", # np.dot unsupported args
111110
"test_tuple_concat_with_reverse_slice", # enumerate
112111
"test_reduce", # functools.reduce
113-
"test_two_d_array_reduction", # np.arange
112+
"test_two_d_array_reduction", # 'memref<?x?xi64>' and result type 'memref<?x?xi32>' are cast incompatible
114113
"test_tuple_concat", # tuple concat
115114
"test_two_d_array_reduction_with_float_sizes", # np.array
116-
"test_two_d_array_reduction_reuse", # np.arange
117-
"test_parfor_array_access_lower_slice", # np.arange
115+
"test_two_d_array_reduction_reuse", # 'memref<?x?xi64>' and result type 'memref<?x?xi32>' are cast incompatible
116+
"test_parfor_array_access_lower_slice", # plier.getitem
118117
"test_size_assertion", # AssertionError not raised
119-
"test_parfor_slice18", # np.arange
118+
"test_parfor_slice18", # cast types mismatch
120119
"test_simple12", # complex128
121120
"test_parfor_slice2", # AssertionError not raised
122121
"test_parfor_slice6", # array.transpose

numba_dpcomp/numba_dpcomp/mlir/tests/test_numpy.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,29 @@ def py_func(d):
838838
assert_equal(py_func(a), jit_func(a))
839839

840840

841+
@parametrize_function_variants(
842+
"py_func",
843+
[
844+
"lambda : np.arange(0)",
845+
"lambda : np.arange(1)",
846+
"lambda : np.arange(7)",
847+
"lambda : np.arange(-1)",
848+
"lambda : np.arange(-1,6)",
849+
"lambda : np.arange(-1,6,1)",
850+
"lambda : np.arange(-1,6,2)",
851+
"lambda : np.arange(-1,6,3)",
852+
"lambda : np.arange(6,-1,-1)",
853+
"lambda : np.arange(6,-1,-2)",
854+
"lambda : np.arange(6,-1,-3)",
855+
"lambda : np.arange(5,dtype=np.int32)",
856+
"lambda : np.arange(5,dtype=np.float32)",
857+
],
858+
)
859+
def test_arange(py_func):
860+
jit_func = njit(py_func)
861+
assert_equal(py_func(), jit_func())
862+
863+
841864
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
842865
def test_dtype_param(dtype):
843866
def py_func(dt):

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/py_linalg_resolver.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -970,11 +970,21 @@ static py::object fromElementsImpl(py::capsule context, py::handle values,
970970
auto &ctx = getPyContext(context);
971971
auto &builder = ctx.builder;
972972
auto loc = ctx.loc;
973-
auto type = unwrapType(dtype);
973+
mlir::Type type;
974+
if (!dtype.is_none())
975+
type = unwrapType(dtype);
974976

975977
llvm::SmallVector<mlir::Value> vals(containerSize(values));
976978
containerIterate(values, [&](auto index, py::handle obj) {
977-
vals[index] = ctx.context.unwrapVal(loc, builder, obj, type);
979+
auto val = [&]() -> mlir::Value {
980+
if (type)
981+
return ctx.context.unwrapVal(loc, builder, obj, type);
982+
983+
auto v = ctx.context.unwrapVal(loc, builder, obj);
984+
type = v.getType();
985+
return v;
986+
}();
987+
vals[index] = val;
978988
});
979989

980990
if (vals.empty())

0 commit comments

Comments
 (0)