diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 5b6a90f806bed..e42fd5d2ce13c 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -313,7 +313,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform", } def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform", - [AllElementTypesMatch<["value", "output"]>, + [AllElementTypesMatch<["value", "output"]>, DestinationStyleOpInterface, DeclareOpInterfaceMethods scf::ValueVector { + auto context = builder.getContext(); Value tileHIter = ivs[0]; Value tileWIter = ivs[1]; Value NIter = ivs[2]; @@ -740,29 +741,41 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, FIter, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1); - TransformMapKeyTy key = {m, r}; - int64_t retRows = 1; - int64_t retCols = 1; - int64_t leftScalarFactor = 1; - int64_t rightScalarFactor = 1; + const TransformMapKeyTy key = {m, r}; + const TransformMatrix &AMatrix = AMatrices.at(key); + const TransformMatrix &ATMatrix = ATMatrices.at(key); + int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) * + (leftTransform ? ATMatrix.scalarFactor : 1); + int64_t retCols = rightTransform ? AMatrix.cols : 1; + int64_t retRows = leftTransform ? ATMatrix.rows : 1; + Value matmulRetValue = extractValue; Value zero = builder.create( loc, rewriter.getZeroAttr(elementType)); - if (leftTransform) { - // Get constant transform matrix AT. - auto it = ATMatrices.find(key); - if (it == ATMatrices.end()) - return {}; - const TransformMatrix &ATMatrix = it->second; - leftScalarFactor = ATMatrix.scalarFactor; - retRows = ATMatrix.rows; + auto affineMap = + AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); + Value heightOffset = + builder.create(loc, affineMap, tileHIter); + Value widthOffset = + builder.create(loc, affineMap, tileWIter); + + Value outInitVal = + extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset, + widthOffset, retRows, retCols, + /*loopNorFIdx=*/0, + /*loopCorFIdx=*/3, /*heightIdx=*/1, + /*widthIdx=*/2); + if (leftTransform) { auto matmulType = RankedTensorType::get({retRows, valueW}, elementType); - auto empty = - builder - .create(loc, matmulType.getShape(), elementType) - .getResult(); - auto init = builder.create(loc, zero, empty).getResult(0); + Value init = outInitVal; + if (rightTransform || scalarFactor != 1) { + auto empty = builder + .create(loc, matmulType.getShape(), + elementType) + .getResult(); + init = builder.create(loc, zero, empty).getResult(0); + } Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType); // Multiply AT x m. @@ -772,21 +785,16 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, } if (rightTransform) { - // Get constant transform matrix T. - auto it = AMatrices.find(key); - if (it == AMatrices.end()) - return {}; - const TransformMatrix &AMatrix = it->second; - - rightScalarFactor = AMatrix.scalarFactor; auto matmulType = RankedTensorType::get({retRows, AMatrix.cols}, elementType); - retCols = AMatrix.cols; - auto empty = - builder - .create(loc, matmulType.getShape(), elementType) - .getResult(); - auto init = builder.create(loc, zero, empty).getResult(0); + Value init = outInitVal; + if (scalarFactor != 1) { + auto empty = builder + .create(loc, matmulType.getShape(), + elementType) + .getResult(); + init = builder.create(loc, zero, empty).getResult(0); + } Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType); // Multiply y = (AT x m) x A. @@ -795,48 +803,36 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, matmulRetValue = matmulOp.getResult(0); } - if (leftScalarFactor * rightScalarFactor != 1) { - // Multiply scalar factor. - Value scalarFactor = builder.create( - loc, - FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor)); + if (scalarFactor != 1) { + // Multiply by scalar factor and add outInitVal. + Value scalarFactorValue = builder.create( + loc, FloatAttr::get(elementType, scalarFactor)); auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); - auto init = builder.create(loc, matmulType.getShape(), - elementType); - auto identityAffineMap = rewriter.getMultiDimIdentityMap(2); SmallVector affineMaps = { - AffineMap::get(2, 0, init.getContext()), identityAffineMap}; - auto broadcastedScalar = + AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap}; + + matmulRetValue = rewriter .create( - loc, matmulType, ValueRange{scalarFactor}, ValueRange{init}, - affineMaps, + loc, matmulType, + ValueRange{scalarFactorValue, matmulRetValue}, + ValueRange{outInitVal}, affineMaps, llvm::ArrayRef{ utils::IteratorType::parallel, utils::IteratorType::parallel}, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create(nestedLoc, args[0]); + auto mulf = nestedBuilder.create( + nestedLoc, args[0], args[1]); + auto addf = nestedBuilder.create( + nestedLoc, mulf.getResult(), args[2]); + nestedBuilder.create(nestedLoc, + addf.getResult()); }) .getResult(0); - - matmulRetValue = builder - .create( - loc, matmulType, - ValueRange{broadcastedScalar, matmulRetValue}, - ValueRange{init}) - .getResult(0); } - auto context = builder.getContext(); - auto affineMap = - AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); - Value heightOffset = - builder.create(loc, affineMap, tileHIter); - Value widthOffset = - builder.create(loc, affineMap, tileWIter); - // Insert (H, W) to (N, H, W, F). Value combinedVal = insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter, diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir index c5760acf94a88..776dc5b748c84 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir @@ -85,31 +85,32 @@ module attributes {transform.with_named_sequence} { // CHECK: scf.yield %[[S9]] // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] // CHECK: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]] +// CHECK: %[[S7:.*]] = tensor.empty() // CHECK: %[[S6:.*]] = linalg.batch_matmul // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] -// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32> -// CHECK: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]]) +// CHECK: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]]) // CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] // CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]]) // CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]]) -// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] +// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] // CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) // CHECK: %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) // CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] +// CHECK: %[[S25:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] // CHECK: %[[S16:.*]] = tensor.empty() : tensor<4x6xf32> // CHECK: %[[S17:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S16]] : tensor<4x6xf32>) -> tensor<4x6xf32> // CHECK: %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_8]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32> // CHECK: %[[S19:.*]] = tensor.empty() : tensor<4x4xf32> // CHECK: %[[S20:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK: %[[S21:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK: %[[S22:.*]] = tensor.empty() : tensor<4x4xf32> -// CHECK: %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S22]] : tensor<4x4xf32>) { -// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S21]] : f32, tensor<4x4xf32>) outs(%[[S25]] : tensor<4x4xf32>) { +// CHECK: ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK: %[[VAL_90:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32 +// CHECK: %[[VAL_91:.*]] = arith.addf %[[VAL_90]], %[[OUT]] : f32 +/// CHECK: linalg.yield %[[VAL_91]] : f32 // CHECK: } -> tensor<4x4xf32> -// CHECK: %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] +// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S23]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] // CHECK: scf.yield %[[INSERTED_SLICE_9]] // CHECK: scf.yield %[[S15]] // CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]]) @@ -218,32 +219,33 @@ module attributes {transform.with_named_sequence} { // CHECK: scf.yield %[[S9]] // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] // CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]] +// CHECK: %[[S7:.*]] = tensor.empty() // CHECK: %[[S6:.*]] = linalg.batch_matmul // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] // CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0] -// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32> -// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]]) +// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[PADDED_8]]) // CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) // CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] // CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) // CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]]) -// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] +// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] // CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) // CHECK: %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) // CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] +// CHECK: %[[S26:.*]] = tensor.extract_slice %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] // CHECK: %[[S17:.*]] = tensor.empty() : tensor<4x6xf32> // CHECK: %[[S18:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32> // CHECK: %[[S19:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_11]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S18]] : tensor<4x6xf32>) -> tensor<4x6xf32> // CHECK: %[[S20:.*]] = tensor.empty() : tensor<4x4xf32> // CHECK: %[[S21:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK: %[[S22:.*]] = linalg.matmul ins(%[[S19]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK: %[[S23:.*]] = tensor.empty() : tensor<4x4xf32> -// CHECK: %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S23]] : tensor<4x4xf32>) { -// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S22]] : f32, tensor<4x4xf32>) outs(%[[S26]] : tensor<4x4xf32>) { +// CHECK: ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK: %[[VAL_104:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32 +// CHECK: %[[VAL_105:.*]] = arith.addf %[[VAL_104]], %[[OUT]] : f32 +/// CHECK: linalg.yield %[[VAL_105]] : f32 // CHECK: } -> tensor<4x4xf32> -// CHECK: %[[S25:.*]] = linalg.mul ins(%[[S24]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S25]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] +// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S24]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] // CHECK: scf.yield %[[INSERTED_SLICE_12]] // CHECK: scf.yield %[[S15]] : tensor<2x4x4x2xf32> // CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) @@ -330,16 +332,17 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]]) // CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] +// CHECK: %[[S15:.*]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1] // CHECK: %[[S9:.*]] = tensor.empty() : tensor<4x1xf32> // CHECK: %[[S10:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S9]] : tensor<4x1xf32>) -> tensor<4x1xf32> // CHECK: %[[S11:.*]] = linalg.matmul ins(%[[CST_0]], %[[EXTRACTED_SLICE]] : tensor<4x6xf32>, tensor<6x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32> -// CHECK: %[[S12:.*]] = tensor.empty() : tensor<4x1xf32> -// CHECK: %[[S13:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S12]] : tensor<4x1xf32>) { -// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: %[[S13:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S11]] : f32, tensor<4x1xf32>) outs(%[[S15]] : tensor<4x1xf32>) { +// CHECK: ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK: %[[VAL_57:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32 +// CHECK: %[[VAL_58:.*]] = arith.addf %[[VAL_57]], %[[OUT]] : f32 +/// CHECK: linalg.yield %[[VAL_58]] : f32 // CHECK: } -> tensor<4x1xf32> -// CHECK: %[[S14:.*]] = linalg.mul ins(%[[S13]], %[[S11]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S12]] : tensor<4x1xf32>) -> tensor<4x1xf32> -// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1] +// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1] // CHECK: scf.yield %[[INSERTED_SLICE]] // CHECK: scf.yield %[[S7]] // CHECK: return %[[S6]] diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir index 21522a2083b46..9598c434aadb8 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir @@ -279,14 +279,14 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index -// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] -// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]] +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG1]]) -> (tensor<2x8x8x2xf32>) +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]] iter_args(%[[ARG6:.*]] = %[[ARG5]]) -> (tensor<2x8x8x2xf32>) // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32> // CHECK: %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]]) // CHECK: %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) // CHECK: %[[S5:.*]] = affine.apply #[[$MAP1]]() // CHECK: %[[S6:.*]] = affine.apply #[[$MAP1]]() -// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32> +// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG6]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32> // ----- @@ -321,10 +321,10 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C2_5:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C2_7:.*]] = arith.constant 2 : index -// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]] -// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]] -// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]] -// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]] +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]] iter_args(%[[ARG9:.*]] = %[[ARG1]]) -> (tensor<3x8x8x5xf32>) +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]] iter_args(%[[ARG10:.*]] = %[[ARG9]]) -> (tensor<3x8x8x5xf32>) +// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]] iter_args(%[[ARG11:.*]] = %[[ARG10]]) +// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]] iter_args(%[[ARG12:.*]] = %[[ARG11]]) // CHECK: %[[C3_8:.*]] = arith.constant 3 : index // CHECK: %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG6]]) // CHECK: %[[C5_9:.*]] = arith.constant 5 : index @@ -334,7 +334,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[S8:.*]] = affine.apply #[[$MAP2]](%[[ARG4]]) // CHECK: %[[S9:.*]] = affine.apply #[[$MAP3]]() // CHECK: %[[S10:.*]] = affine.apply #[[$MAP3]]() -// CHECK: %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor +// CHECK: %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor // ----- @@ -367,14 +367,14 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C1_4:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index -// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] -// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]] -// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]] -// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]] +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[ARG1]]) -> (tensor<3x8x1x5xf32>) +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]] iter_args(%[[ARG10:.*]] = %[[ARG9]]) -> (tensor<3x8x1x5xf32>) +// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]] iter_args(%[[ARG11:.*]] = %[[ARG10]]) -> (tensor<3x8x1x5xf32>) +// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]] iter_args(%[[ARG12:.*]] = %[[ARG11]]) -> (tensor<3x8x1x5xf32>) // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32> // CHECK: %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]]) // CHECK: %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) // CHECK: %[[S7:.*]] = affine.apply #[[$MAP1]]() // CHECK: %[[S8:.*]] = affine.apply #[[$MAP1]]() -// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32> +// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32> // CHECK: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32> diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir index 4369f5f1eab4c..16d06a7473272 100644 --- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir +++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir @@ -100,21 +100,22 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg // CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) { // CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) { // CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6xf32> +// CHECK-NEXT: %[[S20:.*]] = affine.apply #[[$MAP0]](%[[ARG3]]) +// CHECK-NEXT: %[[S21:.*]] = affine.apply #[[$MAP0]](%[[ARG5]]) +// CHECK-NEXT: %[[S22:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<4x4xf32> // CHECK-NEXT: %[[S11:.*]] = tensor.empty() : tensor<4x6xf32> // CHECK-NEXT: %[[S12:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S11]] : tensor<4x6xf32>) -> tensor<4x6xf32> // CHECK-NEXT: %[[S13:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<4x6xf32>) -> tensor<4x6xf32> // CHECK-NEXT: %[[S14:.*]] = tensor.empty() : tensor<4x4xf32> // CHECK-NEXT: %[[S15:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S14]] : tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: %[[S16:.*]] = linalg.matmul ins(%[[S13]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S15]] : tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK-NEXT: %[[S17:.*]] = tensor.empty() : tensor<4x4xf32> -// CHECK-NEXT: %[[S18:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S17]] : tensor<4x4xf32>) { -// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): -// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: %[[S18:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S16]] : f32, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) { +// CHECK-NEXT: ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: %[[VAL_98:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32 +// CHECK-NEXT: %[[VAL_99:.*]] = arith.addf %[[VAL_98]], %[[OUT]] : f32 +// CHECK-NEXT: linalg.yield %[[VAL_99]] : f32 // CHECK-NEXT: } -> tensor<4x4xf32> -// CHECK-NEXT: %[[S19:.*]] = linalg.mul ins(%[[S18]], %[[S16]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S17]] : tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK-NEXT: %[[S20:.*]] = affine.apply #[[$MAP0]](%[[ARG3]]) -// CHECK-NEXT: %[[S21:.*]] = affine.apply #[[$MAP0]](%[[ARG5]]) -// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S18]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32> // CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32> // CHECK-NEXT: } // CHECK-NEXT: scf.yield %[[S9]] : tensor<2x12x12x2xf32>