Skip to content

Commit 431e6ae

Browse files
authored
Support reshape with negative val (#188)
1 parent 3987c17 commit 431e6ae

File tree

6 files changed

+70
-25
lines changed

6 files changed

+70
-25
lines changed

azure-pipelines.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
- script: |
4242
call "C:\Miniconda\Scripts\activate"
4343
call cd numba_dpcomp
44-
pytest -n1 -vv --capture=tee-sys -rX
44+
pytest -n1 -vv --capture=tee-sys -rXF
4545
displayName: 'Tests'
4646
4747
- script: |
@@ -126,7 +126,7 @@ jobs:
126126
export OCL_ICD_FILENAMES=libintelocl.so
127127
export SYCL_DEVICE_FILTER=opencl:cpu
128128
export NUMBA_DISABLE_PERFORMANCE_WARNINGS=1
129-
pytest -n1 -vv --capture=tee-sys -rX
129+
pytest -n1 -vv --capture=tee-sys -rXF
130130
displayName: 'Tests'
131131
132132
- script: |
@@ -196,7 +196,7 @@ jobs:
196196
source /usr/local/miniconda/bin/activate
197197
cd numba_dpcomp
198198
conda activate test_env
199-
pytest -n1 -vv --capture=tee-sys -rX
199+
pytest -n1 -vv --capture=tee-sys -rXF
200200
displayName: 'Tests'
201201
202202
- script: |

numba_dpcomp/numba_dpcomp/mlir/linalg_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def __rmul__(self, o):
6565
def __truediv__(self, o):
6666
return self._binop(self._context, self._ssa_val, o, "/")
6767

68+
def __floordiv__(self, o):
69+
return self._binop(self._context, self._ssa_val, o, "//")
70+
6871
def __lt__(self, o):
6972
return self._binop(self._context, self._ssa_val, o, "lt")
7073

numba_dpcomp/numba_dpcomp/mlir/numpy/funcs.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -587,11 +587,22 @@ def reshape_impl(builder, arg, *new_shape):
587587
if len(new_shape) == 1:
588588
new_shape = new_shape[0]
589589

590-
# TODO: better check
591590
if isinstance(new_shape, tuple):
592-
for s in new_shape:
591+
neg_index = None
592+
for i, s in enumerate(new_shape):
593593
if isinstance(s, int) and s < 0:
594-
return # not supported for now
594+
assert neg_index is None
595+
neg_index = i
596+
if neg_index is not None:
597+
size = 1
598+
for i, s in enumerate(new_shape):
599+
if i != neg_index:
600+
size = size * s
601+
602+
size = size_impl(builder, arg) // size
603+
new_shape = list(new_shape)
604+
new_shape[neg_index] = size
605+
new_shape = tuple(new_shape)
595606

596607
return builder.reshape(arg, new_shape)
597608

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def _gen_tests():
9595
"test_simple18", # np.linalg.svd
9696
"test_linspace", # np.linspace
9797
"test_std", # array.std
98-
"test_reshape_with_neg_one", # unsupported reshape
9998
"test_mvdot", # np.dot unsupported args
10099
"test_array_tuple_concat", # tuple concat
101100
"test_namedtuple1", # namedtuple support
@@ -141,7 +140,6 @@ def _gen_tests():
141140
"test_issue6102", # list support
142141
"test_oversized_tuple_as_arg_to_kernel", # UnsupportedParforsError not raised
143142
"test_issue5942_2", # invalid result
144-
"test_reshape_with_large_neg", # unsupported reshape
145143
"test_parfor_ufunc_typing", # np.isinf
146144
"test_issue_5098", # list support and more
147145
"test_parfor_slice27", # Literal return issue

numba_dpcomp/numba_dpcomp/mlir/tests/test_numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,13 @@ def py_func(a):
10061006
"lambda a: a.reshape((1, a.size))",
10071007
"lambda a: a.reshape(1, a.size)",
10081008
"lambda a: a.reshape((1, a.size, 1))",
1009+
"lambda a: a.reshape((-1, a.size, 1))",
1010+
"lambda a: a.reshape((1, -1, 1))",
1011+
"lambda a: a.reshape((1, a.size, -1))",
10091012
"lambda a: a.reshape(1, a.size, 1)",
1013+
"lambda a: a.reshape(-1, a.size, 1)",
1014+
"lambda a: a.reshape(1, -1, 1)",
1015+
"lambda a: a.reshape(1, a.size, -1)",
10101016
"lambda a: np.reshape(a, a.size)",
10111017
"lambda a: np.reshape(a, (a.size,))",
10121018
"lambda a: np.reshape(a, (a.size,1))",

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/py_linalg_resolver.cpp

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -232,20 +232,6 @@ static llvm::Optional<py::object> getPyLiteral(mlir::Attribute attr) {
232232
return {};
233233
}
234234

235-
static llvm::Optional<py::object> makePyLiteral(mlir::Value val) {
236-
assert(val);
237-
if (auto literal = val.getType().dyn_cast<plier::LiteralType>())
238-
return getPyLiteral(literal.getValue());
239-
240-
if (auto cast = val.getDefiningOp<plier::SignCastOp>())
241-
val = cast.value();
242-
243-
if (auto attr = plier::getConstVal<mlir::Attribute>(val))
244-
return getPyLiteral(attr);
245-
246-
return {};
247-
}
248-
249235
static mlir::Value doCast(mlir::OpBuilder &builder, mlir::Location loc,
250236
mlir::Value val, mlir::Type type) {
251237
if (val.getType() != type)
@@ -284,7 +270,7 @@ struct PyLinalgResolver::Context {
284270
if (auto typevar = type.dyn_cast<plier::TypeVar>())
285271
return createType(typevar.getType());
286272

287-
if (auto literal = makePyLiteral(value))
273+
if (auto literal = makePyLiteral(context, value))
288274
return *literal;
289275

290276
auto ret = var(context, wrapMlir(value));
@@ -400,6 +386,32 @@ struct PyLinalgResolver::Context {
400386

401387
return doCast(builder, loc, unwrapVal(loc, builder, obj), resultType);
402388
}
389+
390+
private:
391+
llvm::Optional<py::object> makePyLiteral(py::capsule context,
392+
mlir::Value val) {
393+
assert(val);
394+
if (auto literal = val.getType().dyn_cast<plier::LiteralType>())
395+
return getPyLiteral(literal.getValue());
396+
397+
if (auto buildTuple = val.getDefiningOp<plier::BuildTupleOp>()) {
398+
auto args = buildTuple.args();
399+
auto count = static_cast<unsigned>(args.size());
400+
py::tuple ret(count);
401+
for (auto i : llvm::seq(0u, count))
402+
ret[i] = createVar(context, args[i]);
403+
404+
return ret;
405+
}
406+
407+
if (auto cast = val.getDefiningOp<plier::SignCastOp>())
408+
val = cast.value();
409+
410+
if (auto attr = plier::getConstVal<mlir::Attribute>(val))
411+
return getPyLiteral(attr);
412+
413+
return {};
414+
}
403415
};
404416

405417
namespace {
@@ -1653,13 +1665,19 @@ static py::object getitemImpl(py::capsule context, py::capsule ssaVal,
16531665
template <typename Op>
16541666
static mlir::Value binopFunc(mlir::Location loc, mlir::OpBuilder &builder,
16551667
mlir::Value lhs, mlir::Value rhs) {
1656-
return builder.create<Op>(loc, lhs, rhs);
1668+
auto lhsVar = doSignCast(builder, loc, lhs);
1669+
auto rhsVar = doSignCast(builder, loc, rhs);
1670+
auto res = builder.create<Op>(loc, lhsVar, rhsVar);
1671+
return doSignCast(builder, loc, res, lhs.getType());
16571672
}
16581673

16591674
template <typename Op>
16601675
static mlir::Value rbinopFunc(mlir::Location loc, mlir::OpBuilder &builder,
16611676
mlir::Value lhs, mlir::Value rhs) {
1662-
return builder.create<Op>(loc, rhs, lhs);
1677+
auto lhsVar = doSignCast(builder, loc, lhs);
1678+
auto rhsVar = doSignCast(builder, loc, rhs);
1679+
auto res = builder.create<Op>(loc, rhsVar, lhsVar);
1680+
return doSignCast(builder, loc, res, lhs.getType());
16631681
}
16641682

16651683
static mlir::Value binopFuncIdiv(mlir::Location loc, mlir::OpBuilder &builder,
@@ -1669,6 +1687,14 @@ static mlir::Value binopFuncIdiv(mlir::Location loc, mlir::OpBuilder &builder,
16691687
return builder.create<mlir::arith::DivFOp>(loc, lhsVar, rhsVar);
16701688
}
16711689

1690+
static mlir::Value binopFFloorDiv(mlir::Location loc, mlir::OpBuilder &builder,
1691+
mlir::Value lhs, mlir::Value rhs) {
1692+
auto lhsVar = doCast(builder, loc, lhs, builder.getF64Type());
1693+
auto rhsVar = doCast(builder, loc, rhs, builder.getF64Type());
1694+
auto res = builder.create<mlir::arith::DivFOp>(loc, lhsVar, rhsVar);
1695+
return builder.create<mlir::math::FloorOp>(loc, res);
1696+
}
1697+
16721698
template <mlir::arith::CmpIPredicate Pred>
16731699
static mlir::Value binopCmpI(mlir::Location loc, mlir::OpBuilder &builder,
16741700
mlir::Value lhs, mlir::Value rhs) {
@@ -1713,6 +1739,7 @@ static py::object binopImpl(py::capsule context, py::capsule ssaVal,
17131739
&rbinopFunc<mlir::arith::SubFOp>},
17141740
{"*", &binopFunc<mlir::arith::MulIOp>, &binopFunc<mlir::arith::MulFOp>},
17151741
{"/", &binopFuncIdiv, &binopFunc<mlir::arith::DivFOp>},
1742+
{"//", &binopFunc<mlir::arith::DivSIOp>, &binopFFloorDiv},
17161743

17171744
{"lt", &binopCmpI<mlir::arith::CmpIPredicate::slt>,
17181745
&binopCmpF<mlir::arith::CmpFPredicate::OLT>},

0 commit comments

Comments
 (0)