diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp index 4be0e06fe2a5e..fddd7c51bfbc8 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp @@ -40,11 +40,11 @@ arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter( addConversion([this](FunctionType ty) -> std::optional { SmallVector inputs; if (failed(convertTypes(ty.getInputs(), inputs))) - return std::nullopt; + return nullptr; SmallVector results; if (failed(convertTypes(ty.getResults(), results))) - return std::nullopt; + return nullptr; return FunctionType::get(ty.getContext(), inputs, results); }); diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 9efea066a03c8..28f9061d9873b 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -169,8 +169,9 @@ struct ConvertMemRefAllocation final : OpConversionPattern { std::is_same(), "expected only memref::AllocOp or memref::AllocaOp"); auto currentType = cast(op.getMemref().getType()); - auto newResultType = dyn_cast( - this->getTypeConverter()->convertType(op.getType())); + auto newResultType = + this->getTypeConverter()->template convertType( + op.getType()); if (!newResultType) { return rewriter.notifyMatchFailure( op->getLoc(), @@ -378,7 +379,7 @@ struct ConvertMemRefReinterpretCast final matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MemRefType newTy = - dyn_cast(getTypeConverter()->convertType(op.getType())); + getTypeConverter()->convertType(op.getType()); if (!newTy) { return rewriter.notifyMatchFailure( op->getLoc(), @@ -466,8 +467,8 @@ struct ConvertMemRefSubview final : OpConversionPattern { LogicalResult matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MemRefType newTy = dyn_cast( - getTypeConverter()->convertType(subViewOp.getType())); + MemRefType newTy = + getTypeConverter()->convertType(subViewOp.getType()); if (!newTy) { return rewriter.notifyMatchFailure( subViewOp->getLoc(), @@ -632,14 +633,14 @@ void memref::populateMemRefNarrowTypeEmulationConversions( SmallVector strides; int64_t offset; if (failed(getStridesAndOffset(ty, strides, offset))) - return std::nullopt; + return nullptr; if (!strides.empty() && strides.back() != 1) - return std::nullopt; + return nullptr; auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth, intTy.getSignedness()); if (!newElemTy) - return std::nullopt; + return nullptr; StridedLayoutAttr layoutAttr; // If the offset is 0, we do not need a strided layout as the stride is diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp index bc4535f97acf0..49b71625291db 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -159,7 +159,7 @@ void memref::populateMemRefWideIntEmulationConversions( Type newElemTy = typeConverter.convertType(intTy); if (!newElemTy) - return std::nullopt; + return nullptr; return ty.cloneWith(std::nullopt, newElemTy); }); diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir index 540da239fced0..1d6cbfa343ba5 100644 --- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir @@ -203,7 +203,6 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 { // ----- - func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 { %c0 = arith.constant 0 : index %arr = memref.alloc() : memref<40x40xi4> @@ -543,13 +542,15 @@ func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>) // ----- -!colMajor = memref<8x8xi4, strided<[1, 8]>> -func.func @copy_distinct_layouts(%idx : index) -> i4 { - %c0 = arith.constant 0 : index - %arr = memref.alloc() : memref<8x8xi4> - %arr2 = memref.alloc() : !colMajor - // expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}} - memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor - %ld = memref.load %arr2[%c0, %c0] : !colMajor - return %ld : i4 +func.func @alloc_non_contiguous() { + // expected-error @+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}} + %arr = memref.alloc() : memref<8x8xi4, strided<[1, 8]>> + return +} + +// ----- + +// expected-error @+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}} +func.func @argument_non_contiguous(%arg0 : memref<8x8xi4, strided<[1, 8]>>) { + return } diff --git a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir index 65ac5beed0a1d..994e400bd73c1 100644 --- a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir +++ b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s +// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" %s \ +// RUN: --split-input-file --verify-diagnostics | FileCheck %s // Expect no conversions, i32 is supported. // CHECK-LABEL: func @memref_i32 @@ -15,6 +16,8 @@ func.func @memref_i32() { return } +// ----- + // Expect no conversions, f64 is not an integer type. // CHECK-LABEL: func @memref_f32 // CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1> @@ -30,6 +33,8 @@ func.func @memref_f32() { return } +// ----- + // CHECK-LABEL: func @alloc_load_store_i64 // CHECK: [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32> // CHECK-NEXT: [[M:%.+]] = memref.alloc() : memref<4xvector<2xi32>, 1> @@ -45,6 +50,7 @@ func.func @alloc_load_store_i64() { return } +// ----- // CHECK-LABEL: func @alloc_load_store_i64_nontemporal // CHECK: [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32> @@ -60,3 +66,30 @@ func.func @alloc_load_store_i64_nontemporal() { memref.store %c1, %m[%c0] {nontemporal = true} : memref<4xi64, 1> return } + +// ----- + +// Make sure we do not crash on unsupported types. +func.func @alloc_i128() { + // expected-error@+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}} + %m = memref.alloc() : memref<4xi128, 1> + return +} + +// ----- + +func.func @load_i128(%m: memref<4xi128, 1>) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{failed to legalize operation 'memref.load' that was explicitly marked illegal}} + %v = memref.load %m[%c0] : memref<4xi128, 1> + return +} + +// ----- + +func.func @store_i128(%c1: i128, %m: memref<4xi128, 1>) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{failed to legalize operation 'memref.store' that was explicitly marked illegal}} + memref.store %c1, %m[%c0] : memref<4xi128, 1> + return +}