Skip to content

Commit 4ada912

Browse files
authored
Validate memref element type in mapArrayType (#156)
1 parent 1cde2ad commit 4ada912

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_linalg.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,11 @@ static mlir::Type mapArrayType(mlir::MLIRContext &ctx,
179179
desc->layout == ArrayLayout::A) {
180180
if (auto type =
181181
conveter.convertType(plier::PyType::get(&ctx, desc->name))) {
182-
llvm::SmallVector<int64_t> shape(desc->dims,
183-
mlir::ShapedType::kDynamicSize);
184-
return mlir::MemRefType::get(shape, type);
182+
if (mlir::BaseMemRefType::isValidElementType(type)) {
183+
llvm::SmallVector<int64_t> shape(desc->dims,
184+
mlir::ShapedType::kDynamicSize);
185+
return mlir::MemRefType::get(shape, type);
186+
}
185187
}
186188
}
187189
}

0 commit comments

Comments
 (0)