Skip to content

Commit 378958a

Browse files
authored
Fix GPU test (#190)
By fixing subviewImpl. Also, reeanble parfor test
1 parent 3092132 commit 378958a

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def _gen_tests():
7878
"test_prange27", # Literal return issue
7979
"test_simple01", # Empty shape not failed
8080
"test_kmeans", # List suport
81-
"test_simple14", # Slice shape mismatch
8281
"test_ndarray_fill", # array.fill
8382
"test_fuse_argmin_argmax_max_min", # numpy argmin, argmax
8483
"test_max", # max reduction

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/py_linalg_resolver.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1439,7 +1439,17 @@ py::object subviewImpl(py::capsule context, py::handle src, py::handle offsets,
14391439
for (auto i : llvm::seq(size_t(0), ret.size())) {
14401440
auto dim = builder.createOrFold<mlir::tensor::DimOp>(
14411441
loc, origSrcVal, static_cast<int64_t>(i));
1442-
auto offset = offsetVals[i].get<mlir::Value>();
1442+
auto offset = [&]() -> mlir::Value {
1443+
auto off = offsetVals[i];
1444+
if (off.is<mlir::Value>())
1445+
return off.get<mlir::Value>();
1446+
1447+
auto val = off.get<mlir::Attribute>()
1448+
.cast<mlir::IntegerAttr>()
1449+
.getValue()
1450+
.getSExtValue();
1451+
return builder.create<mlir::arith::ConstantIndexOp>(loc, val);
1452+
}();
14431453
auto size = builder.createOrFold<mlir::arith::SubIOp>(loc, dim, offset);
14441454
ret[i] = size;
14451455
}

0 commit comments

Comments
 (0)