Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2328,7 +2328,22 @@ LogicalResult ExtractSliceOp::verify() {
// Verify result type against inferred type.
RankedTensorType expectedType = ExtractSliceOp::inferResultType(
getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
SliceVerificationResult result = isRankReducedType(expectedType, getType());
SliceVerificationResult result;
if (getSizes().size() != 0) {
bool hasNonCstValue = false;
for (OpFoldResult size : getSizes()) {
std::optional<int64_t> cst = getConstantIntValue(size);
if (!cst) {
hasNonCstValue = true;
break;
}
}
if (hasNonCstValue && llvm::cast<ShapedType>(getType()).hasStaticShape()) {
result = SliceVerificationResult::SizeMismatch;
return produceSliceErrorMsg(result, *this, expectedType);
}
}
result = isRankReducedType(expectedType, getType());
return produceSliceErrorMsg(result, *this, expectedType);
}

Expand Down Expand Up @@ -2700,10 +2715,26 @@ static SliceVerificationResult verifyInsertSliceOp(

/// Verifier for InsertSliceOp.
///
/// NOTE(review): this span is a flattened diff hunk, so it appears to hold both
/// the replaced statements and their replacement; `expectedType` and `result`
/// are each declared twice as a consequence. Confirm which lines are live
/// against the real TensorOps.cpp before relying on the comments below.
LogicalResult InsertSliceOp::verify() {
// Presumably the pre-change implementation: delegate to verifyInsertSliceOp,
// passing `expectedType` as an out-parameter.
RankedTensorType expectedType;
SliceVerificationResult result =
verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
getStaticSizes(), getStaticStrides(), &expectedType);
// insert_slice is the inverse of extract_slice, use the same type
// inference.
RankedTensorType expectedType = ExtractSliceOp::inferResultType(
getType(), getStaticOffsets(), getStaticSizes(), getStaticStrides());
SliceVerificationResult result;
// Only inspect the sizes when getSizes() is non-empty.
if (getSizes().size() != 0) {
// Scan for the first size that getConstantIntValue cannot resolve to an
// integer; one such entry is enough, hence the early break.
bool hasNonCstValue = false;
for (OpFoldResult size : getSizes()) {
std::optional<int64_t> cst = getConstantIntValue(size);
if (!cst) {
hasNonCstValue = true;
break;
}
}
// A non-constant size combined with a source type whose shape is fully
// static is reported as SizeMismatch, bypassing the isRankReducedType check.
if (hasNonCstValue && getSourceType().hasStaticShape()) {
result = SliceVerificationResult::SizeMismatch;
return produceSliceErrorMsg(result, *this, expectedType);
}
}
// Otherwise classify the source type against the inferred type.
result = isRankReducedType(expectedType, getSourceType());
// NOTE(review): produceSliceErrorMsg presumably maps `result` to success or
// an emitted diagnostic; its definition is not visible here.
return produceSliceErrorMsg(result, *this, expectedType);
}

Expand Down
10 changes: 8 additions & 2 deletions mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {

/// Helper function to dispatch an OpFoldResult into `staticVec` if:
/// a) it is an IntegerAttr
/// b) it is a constant integer value
/// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
Expand All @@ -54,8 +55,13 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
staticVec.push_back(apInt.getSExtValue());
return;
}
dynamicVec.push_back(v);
staticVec.push_back(ShapedType::kDynamic);
std::optional<int64_t> maybeConstantInt = getConstantIntValue(ofr);
if (!maybeConstantInt) {
dynamicVec.push_back(v);
staticVec.push_back(ShapedType::kDynamic);
} else {
staticVec.push_back(*maybeConstantInt);
}
}

std::pair<int64_t, OpFoldResult>
Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,8 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
int64_t origSize = originalShape[originalIdx];
// if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
(ShapedType::isDynamic(reducedShape[reducedIdx]) ||
ShapedType::isDynamic(origSize))) {
(ShapedType::isDynamic(origSize) ||
ShapedType::isDynamic(reducedShape[reducedIdx]))) {
reducedIdx++;
continue;
}
Expand All @@ -448,7 +448,7 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
return std::nullopt;
}
// The whole reducedShape must be scanned, otherwise we bail.
if (reducedIdx != reducedRank)
if (reducedIdx != reducedRank && originalRank != 1)
return std::nullopt;
return unusedDims;
}
Expand All @@ -472,8 +472,8 @@ mlir::isRankReducedType(ShapedType originalType,
if (candidateReducedRank > originalRank)
return SliceVerificationResult::RankTooLarge;

auto optionalUnusedDimsMask =
computeRankReductionMask(originalShape, candidateReducedShape);
auto optionalUnusedDimsMask = computeRankReductionMask(
originalShape, candidateReducedShape, /*matchDynamic=*/true);

// Sizes cannot be matched in case empty vector is returned.
if (!optionalUnusedDimsMask)
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,9 @@ func.func @split_at(%shape: tensor<?xindex>, %index: index) -> (tensor<?xindex>,
// CHECK-NEXT: %[[ISNEG:.*]] = arith.cmpi slt, %[[INDEX]], %[[C0]] : index
// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[HEAD:.*]] = tensor.extract_slice %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
// CHECK-NEXT: %[[HEAD:.*]] = tensor.extract_slice %[[SHAPE]][0] [%[[SELECT]]] [1] : tensor<?xindex> to tensor<?xindex>
// CHECK-NEXT: %[[TAIL_SIZE:.*]] = arith.subi %[[RANK]], %[[SELECT]] : index
// CHECK-NEXT: %[[TAIL:.*]] = tensor.extract_slice %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
// CHECK-NEXT: %[[TAIL:.*]] = tensor.extract_slice %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [1] : tensor<?xindex> to tensor<?xindex>
// CHECK-NEXT: return %[[HEAD]], %[[TAIL]] : tensor<?xindex>, tensor<?xindex>
%head, %tail = "shape.split_at"(%shape, %index) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
return %head, %tail : tensor<?xindex>, tensor<?xindex>
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ func.func @scatter_test(%values_in: tensor<3x7x5xi32>, %indices : tensor<3x6xi32
// CHECK: [[RESULT_1:%.+]] = scf.for [[ITER_VAR_1:%.+]] = [[C_0_0]] to [[C_6]] step [[C_1_0]] iter_args([[ITER_ARG_1:%.+]] = [[ITER_ARG_0]]) -> (tensor<3x7x5xi32>) {
// CHECK-DAG: [[EXTRACTED:%.+]] = tensor.extract [[INDICES]][[[ITER_VAR_0]], [[ITER_VAR_1]]] : tensor<3x6xi32>
// CHECK-DAG: [[EXTRACTED_CAST:%.+]] = arith.index_cast [[EXTRACTED]] : i32 to index
// CHECK-DAG: [[EXTRACTED_SLICE:%.+]] = tensor.extract_slice [[INPUT]][[[ITER_VAR_0]], [[ITER_VAR_1]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
// CHECK-DAG: [[INSERTED_SLICE:%.+]] = tensor.insert_slice [[EXTRACTED_SLICE]] into [[ITER_ARG_1]][[[ITER_VAR_0]], [[EXTRACTED_CAST]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
// CHECK-DAG: [[EXTRACTED_SLICE:%.+]] = tensor.extract_slice [[INPUT]][[[ITER_VAR_0]], [[ITER_VAR_1]], 0] [1, 1, 5] [1, 1, 1] : tensor<3x6x5xi32> to tensor<1x1x5xi32>
// CHECK-DAG: [[INSERTED_SLICE:%.+]] = tensor.insert_slice [[EXTRACTED_SLICE]] into [[ITER_ARG_1]][[[ITER_VAR_0]], [[EXTRACTED_CAST]], 0] [1, 1, 5] [1, 1, 1] : tensor<1x1x5xi32> into tensor<3x7x5xi32>
// CHECK: scf.yield [[INSERTED_SLICE]] : tensor<3x7x5xi32>
// CHECK: }
// CHECK: scf.yield [[RESULT_1]] : tensor<3x7x5xi32>
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
// CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.pad %[[ARG0]] low[1, 3] high[2, 4] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
%1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)
Expand Down Expand Up @@ -501,7 +501,7 @@ func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
// CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.pad %[[ARG0]] low[1, 3] high[2, 4] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
%1 = arith.constant dense<42.0> : tensor<f32>
Expand All @@ -519,26 +519,26 @@ func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
// CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.pad %[[ARG0]] low[1, 3] high[2, 4] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
%1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}

func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<2x9xf32>) {
%0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
// TODO: Output contains multiple "arith.constant 1 : index".
// CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
// CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: tensor.pad %[[ARG0]] low[-1, 3] high[2, 4] {
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
%1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
// CHECK: } : tensor<1x2xf32> to tensor<2x9xf32>
%1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<2x9xf32>)
return %1 : tensor<2x9xf32>
}

// -----
Expand Down
17 changes: 7 additions & 10 deletions mlir/test/Dialect/ArmSME/vector-legalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,8 @@ func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memre
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
// CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
// CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
// CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[READ_SUBVIEW]] (d0, d1) -> (d1, d0) : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<4x?xf32, strided<[1, ?], offset: ?>>
// CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] {in_bounds = [true, false]} : memref<4x?xf32, strided<[1, ?], offset: ?>>, vector<4x[8]xf32>
// CHECK-NEXT: return %[[LEGAL_READ]]
%pad = arith.constant 0.0 : f32
%illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
Expand All @@ -434,11 +433,10 @@ func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memre
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>
func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index, %memref: memref<?x?xf32>, %a: index, %b: index) -> vector<4x[8]xf32> {
// CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
// CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[READ_SUBVIEW]]
// CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
// CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
// CHECK-SAME: %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
// CHECK-SAME: %[[MASK]] {in_bounds = [true, false]} : memref<4x?xf32, strided<[1, ?], offset: ?>>, vector<4x[8]xf32>
// CHECK-NEXT: return %[[LEGAL_READ]]
%pad = arith.constant 0.0 : f32
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
Expand All @@ -453,8 +451,7 @@ func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>
func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
// CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
// CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[READ_SUBVIEW]]
// CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
// CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
// CHECK-NEXT: return %[[EXT_TYPE]]
Expand Down Expand Up @@ -514,7 +511,7 @@ func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector

// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
// CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
// CHECK: vector.transfer_read {{.*}} : memref<1x?xf32, {{.*}}>, vector<1x[4]xf32>
// CHECK-NOT: vector.shape_cast
%pad = arith.constant 0.0 : f32
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
Expand All @@ -526,7 +523,7 @@ func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: m

// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
// CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
// CHECK: vector.transfer_read {{.*}} : memref<1x?xf32, {{.*}}>, vector<1x[4]xf32>
// CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
%pad = arith.constant 0.0 : f32
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/GPU/decompose-memrefs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
// CHECK: gpu.launch
// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [2, 2, 2], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
// CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -118,7 +118,7 @@ func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[STRIDES]]#0]
// CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#1]
// CHECK: %[[IDX2:.*]] = affine.apply #[[MAP2]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[IDX]], %[[IDX1]], 4]
// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [2, 2, 2], strides: [%[[IDX]], %[[IDX1]], 4]
// CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
%c0 = arith.constant 0 : index
Expand Down
Loading
Loading