Skip to content

Commit fbe474d

Browse files
authored
[python] Properly type arange result (#205)
1 parent 18d5b1f commit fbe474d

File tree

3 files changed

+75
-3
lines changed

3 files changed

+75
-3
lines changed

numba_dpcomp/numba_dpcomp/mlir/numpy/funcs.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,16 @@ def ones_like_impl(builder, arr):
396396
return _init_impl(builder, arr.shape, arr.dtype, 1)
397397

398398

399+
_is_np_long64 = numpy.int_ == numpy.int64
400+
401+
402+
def _get_numpy_long(builder):
403+
if _is_np_long64:
404+
return builder.int64
405+
else:
406+
return builder.int32
407+
408+
399409
@register_func("numpy.arange", numpy.arange)
400410
def arange_impl(builder, start, stop=None, step=None, dtype=None):
401411
if stop is None:
@@ -406,7 +416,7 @@ def arange_impl(builder, start, stop=None, step=None, dtype=None):
406416
step = 1
407417

408418
if dtype is None:
409-
dtype = builder.int64
419+
dtype = _get_numpy_long(builder)
410420

411421
inc = builder.select(step < 0, 1, -1)
412422
count = (stop - start + step + inc) // step

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,8 @@ def _gen_tests():
107107
"test_parfor_array_access4", # np.dot unsupported args
108108
"test_tuple_concat_with_reverse_slice", # enumerate
109109
"test_reduce", # functools.reduce
110-
"test_two_d_array_reduction", # 'memref<?x?xi64>' and result type 'memref<?x?xi32>' are cast incompatible
111110
"test_tuple_concat", # enumerate
112111
"test_two_d_array_reduction_with_float_sizes", # np.array
113-
"test_two_d_array_reduction_reuse", # 'memref<?x?xi64>' and result type 'memref<?x?xi32>' are cast incompatible
114112
"test_parfor_array_access_lower_slice", # plier.getitem
115113
"test_size_assertion", # AssertionError not raised
116114
"test_parfor_slice18", # cast types mismatch

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_linalg.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,6 +1616,67 @@ struct LowerTupleCasts : public mlir::OpConversionPattern<plier::CastOp> {
16161616
}
16171617
};
16181618

1619+
static mlir::Value convertTensorElements(mlir::OpBuilder &builder,
1620+
mlir::Location loc, mlir::Value src,
1621+
mlir::ShapedType dstType,
1622+
mlir::Type origSrcElemType,
1623+
mlir::Type origDstElemType) {
1624+
assert(src.getType().isa<mlir::ShapedType>());
1625+
auto srcType = src.getType().cast<mlir::ShapedType>();
1626+
assert(srcType.getRank() == dstType.getRank());
1627+
if (srcType.getElementType() == dstType.getElementType())
1628+
return src;
1629+
1630+
if (srcType.isa<mlir::MemRefType>())
1631+
src = builder.create<mlir::bufferization::ToTensorOp>(loc, src);
1632+
1633+
auto rank = static_cast<unsigned>(srcType.getRank());
1634+
llvm::SmallVector<mlir::Value> shape(rank);
1635+
for (auto i : llvm::seq(0u, rank))
1636+
shape[i] = builder.create<mlir::tensor::DimOp>(loc, src, i);
1637+
1638+
mlir::Value init = builder.create<mlir::linalg::InitTensorOp>(
1639+
loc, shape, dstType.getElementType());
1640+
1641+
auto affineMap =
1642+
mlir::AffineMap::getMultiDimIdentityMap(rank, builder.getContext());
1643+
const mlir::AffineMap maps[] = {
1644+
affineMap,
1645+
affineMap,
1646+
};
1647+
1648+
llvm::SmallVector<mlir::StringRef> iterators(rank, "parallel");
1649+
1650+
auto bodyBuilder = [&](mlir::OpBuilder &b, mlir::Location l,
1651+
mlir::ValueRange args) {
1652+
assert(args.size() == 2);
1653+
auto doSignCast = [&](mlir::Value src, mlir::Type dstType) -> mlir::Value {
1654+
if (src.getType() != dstType)
1655+
return b.create<plier::SignCastOp>(l, dstType, src);
1656+
1657+
return src;
1658+
};
1659+
auto arg = doSignCast(args.front(), origSrcElemType);
1660+
arg = b.create<plier::CastOp>(l, origDstElemType, arg);
1661+
arg = doSignCast(arg, dstType.getElementType());
1662+
b.create<mlir::linalg::YieldOp>(l, arg);
1663+
};
1664+
mlir::Value res =
1665+
builder
1666+
.create<mlir::linalg::GenericOp>(loc, init.getType(), src, init, maps,
1667+
iterators, bodyBuilder)
1668+
.getResult(0);
1669+
1670+
if (dstType.isa<mlir::MemRefType>()) {
1671+
auto memrefType =
1672+
mlir::MemRefType::get(dstType.getShape(), dstType.getElementType());
1673+
res = builder.create<mlir::bufferization::ToMemrefOp>(loc, memrefType, res);
1674+
}
1675+
1676+
rerunScfPipeline(res.getDefiningOp());
1677+
return res;
1678+
}
1679+
16191680
struct LowerTensorCasts : public mlir::OpConversionPattern<plier::CastOp> {
16201681
using OpConversionPattern::OpConversionPattern;
16211682

@@ -1653,6 +1714,9 @@ struct LowerTensorCasts : public mlir::OpConversionPattern<plier::CastOp> {
16531714
value =
16541715
rewriter.createOrFold<plier::SignCastOp>(loc, signlessSrcType, value);
16551716

1717+
value = convertTensorElements(rewriter, loc, value, signlessDstType,
1718+
srcElem, dstElem);
1719+
16561720
bool isSrcMemref = srcType.isa<mlir::MemRefType>();
16571721
bool isDstMemref = dstType.isa<mlir::MemRefType>();
16581722

0 commit comments

Comments
 (0)