[MLIR][Linalg] Fix winograd op lowering for types smaller than f32 #158500
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Isaac Nudelman (nuudlman)

Changes

The Winograd transform constant arrays are always emitted as f32, but previously the attribute creation would pass through the original element type. If that type was smaller (such as f16), attribute creation failed with an assertion. This patch fixes the issue by always promoting the type of the Winograd constants to f32, and adds a test for this case.

Full diff: https://github.com/llvm/llvm-project/pull/158500.diff

2 Files Affected:
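As a hypothetical, minimal illustration of the failure mode described above (not part of this PR; it assumes an MLIR development tree to compile and link against), building a dense float attribute from f32 data with a narrower element type trips an assertion in a debug build:

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/ArrayRef.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  const float data[] = {1.0f, 2.0f, 3.0f, 4.0f};

  // Fine: the tensor element type (f32) matches the C++ data type.
  auto f32Type = mlir::RankedTensorType::get({2, 2}, b.getF32Type());
  auto okAttr =
      mlir::DenseFPElementsAttr::get(f32Type, llvm::ArrayRef<float>(data));
  (void)okAttr;

  // Asserts in a debug build: the element type is f16 but the data is f32.
  // This is the mismatch that occurred when the original element type was
  // passed through to the transform-matrix constant.
  auto f16Type = mlir::RankedTensorType::get({2, 2}, b.getF16Type());
  auto badAttr =
      mlir::DenseFPElementsAttr::get(f16Type, llvm::ArrayRef<float>(data));
  (void)badAttr;
  return 0;
}
```

The change below avoids this mismatch by always constructing the constant tensor type as f32, matching the f32 data in the transform tables.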
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index b80b27fe5fcc5..b875b24c8fda0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -186,11 +186,12 @@ constexpr float A_2x2_5x5[] = {
/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
- TransformMatrix(const float *table, int64_t rows, int64_t cols,
+ TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
int64_t scalarFactor = 1)
- : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+ : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {
+ }
- const float *table;
+ ArrayRef<float> table;
int64_t rows;
int64_t cols;
int64_t scalarFactor;
@@ -198,15 +199,14 @@ struct TransformMatrix {
/// Utility function to convert constant array to arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
- TransformMatrix transform, Type type) {
- ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
-
+ TransformMatrix transform) {
+ assert(transform.table.size() == static_cast<size_t>(transform.rows * transform.cols));
+ ArrayRef<float> constVec(transform.table.data(), transform.rows * transform.cols);
+ SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
return arith::ConstantOp::create(
builder, loc,
DenseFPElementsAttr::get(
- RankedTensorType::get(
- SmallVector<int64_t>{transform.rows, transform.cols}, type),
- constVec));
+ RankedTensorType::get(shape, builder.getF32Type()), constVec));
}
/// Extract height x width data from 4D tensors.
@@ -404,7 +404,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
- Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
+ Value G = create2DTransformMatrix(builder, loc, GMatrix);
// Multiply G x g.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{G, extractFilter},
@@ -427,7 +427,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
- Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
+ Value GT = create2DTransformMatrix(builder, loc, GTMatrix);
// Multiply u = (G x g) x GT.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{matmulRetValue, GT},
@@ -552,7 +552,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
Value BT =
- create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+ create2DTransformMatrix(builder, loc, BTMatrix);
// Multiply BT x d.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{BT, matmulRetValue},
@@ -575,7 +575,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
Value B =
- create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+ create2DTransformMatrix(builder, loc, BMatrix);
// Multiply v = (BT x d) x B.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{matmulRetValue, B},
@@ -783,7 +783,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
}
- Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
+ Value AT = create2DTransformMatrix(builder, loc, ATMatrix);
// Multiply AT x m.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{AT, matmulRetValue},
@@ -802,7 +802,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
}
- Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
+ Value A = create2DTransformMatrix(builder, loc, AMatrix);
// Multiply y = (AT x m) x A.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{matmulRetValue, A},
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index c7b0bd51308ba..4bcb9b0c2c465 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -127,3 +127,119 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<6x6x5x2xf16>
+ %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf16>) outs(%0 : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16> // no-crash
+ %2 = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+ %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x6x5xf16>) outs(%2 : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16> // no-crash
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+ %4 = tensor.empty() : tensor<36x2x2xf32>
+ %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+ %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%5 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+ %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+ %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+ return %7 : tensor<2x4x4x2xf32>
+}
+
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_type_promotion(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x6x6x5xf16>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<2x3x3x5xf16>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<1xf32>,
+// CHECK-SAME: %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT: %[[VAL_20:.*]] = tensor.extract_slice %[[ARG1]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_11]], %[[VAL_18]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf16> to tensor<3x3xf16>
+// CHECK-NEXT: %[[VAL_21:.*]] = tensor.empty() : tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_22:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_21]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_23:.*]] = linalg.matmul ins(%[[VAL_6]], %[[VAL_20]] : tensor<6x3xf32>, tensor<3x3xf16>) outs(%[[VAL_22]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_24:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_25:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_24]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_26:.*]] = linalg.matmul ins(%[[VAL_23]], %[[VAL_5]] : tensor<6x3xf16>, tensor<3x6xf32>) outs(%[[VAL_25]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_27:.*]] = tensor.insert_slice %[[VAL_26]] into %[[VAL_19]]{{\[}}%[[VAL_11]], %[[VAL_11]], %[[VAL_18]], %[[VAL_15]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x5x2xf16>
+// CHECK-NEXT: scf.yield %[[VAL_27]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_17]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[VAL_28:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_32:.*]] = scf.for %[[VAL_33:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_34:.*]] = %[[VAL_31]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_35:.*]] = scf.for %[[VAL_36:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_37:.*]] = %[[VAL_34]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_38:.*]] = scf.for %[[VAL_39:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_40:.*]] = %[[VAL_37]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_41:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_30]])
+// CHECK-NEXT: %[[VAL_42:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_33]])
+// CHECK-NEXT: %[[VAL_43:.*]] = tensor.extract_slice %[[ARG0]]{{\[}}%[[VAL_36]], %[[VAL_41]], %[[VAL_42]], %[[VAL_39]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf16> to tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_44:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_45:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_44]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_46:.*]] = linalg.matmul ins(%[[VAL_4]], %[[VAL_43]] : tensor<6x6xf32>, tensor<6x6xf16>) outs(%[[VAL_45]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_47:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_48:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_47]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_49:.*]] = linalg.matmul ins(%[[VAL_46]], %[[VAL_3]] : tensor<6x6xf16>, tensor<6x6xf32>) outs(%[[VAL_48]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_50:.*]] = tensor.insert_slice %[[VAL_49]] into %[[VAL_40]][0, 0, %[[VAL_30]], %[[VAL_33]], %[[VAL_36]], %[[VAL_39]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: scf.yield %[[VAL_50]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_38]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_35]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_32]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[VAL_51:.*]] = tensor.collapse_shape %[[VAL_14]] {{\[\[}}0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT: %[[VAL_52:.*]] = tensor.collapse_shape %[[VAL_29]] {{\[\[}}0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT: %[[VAL_53:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_54:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_53]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_55:.*]] = linalg.batch_matmul ins(%[[VAL_52]], %[[VAL_51]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[VAL_54]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_56:.*]] = tensor.expand_shape %[[VAL_55]] {{\[\[}}0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_59:.*]] = %[[ARG3]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_60:.*]] = scf.for %[[VAL_61:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_62:.*]] = %[[VAL_59]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_65:.*]] = %[[VAL_62]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_66:.*]] = scf.for %[[VAL_67:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_68:.*]] = %[[VAL_65]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_69:.*]] = tensor.extract_slice %[[VAL_56]][0, 0, %[[VAL_58]], %[[VAL_61]], %[[VAL_64]], %[[VAL_67]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6xf32>
+// CHECK-NEXT: %[[VAL_70:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_58]])
+// CHECK-NEXT: %[[VAL_71:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_61]])
+// CHECK-NEXT: %[[VAL_72:.*]] = tensor.extract_slice %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x4x4x2xf32> to tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_73:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_74:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_73]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_75:.*]] = linalg.matmul ins(%[[VAL_2]], %[[VAL_69]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[VAL_74]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_76:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_77:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_76]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_78:.*]] = linalg.matmul ins(%[[VAL_75]], %[[VAL_1]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[VAL_77]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_79:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_78]] : f32, tensor<4x4xf32>) outs(%[[VAL_72]] : tensor<4x4xf32>) {
+// CHECK-NEXT: ^bb0(%[[VAL_80:.*]]: f32, %[[VAL_81:.*]]: f32, %[[VAL_82:.*]]: f32):
+// CHECK-NEXT: %[[VAL_83:.*]] = arith.mulf %[[VAL_80]], %[[VAL_81]] : f32
+// CHECK-NEXT: %[[VAL_84:.*]] = arith.addf %[[VAL_83]], %[[VAL_82]] : f32
+// CHECK-NEXT: linalg.yield %[[VAL_84]] : f32
+// CHECK-NEXT: } -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_85:.*]] = tensor.insert_slice %[[VAL_79]] into %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT: scf.yield %[[VAL_85]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_66]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_63]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_60]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[VAL_57]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
✅ With the latest revision this PR passed the C/C++ code formatter.
@hanhanW @nirvedhmeshram Would one of you be able to review this?
ping @Hsiangkai |
LGTM. Added comments below.
Why not use the element type of the input/filter values for all of these constant matrices? Something like …
Thanks for this; I tried to figure out how to emit f16 attributes and failed. I will update the patch to do this instead.
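A minimal sketch of one way that could look, converting the f32 table into the requested element type via APFloat before building the attribute. This is an illustration only, not the actual follow-up change in this PR; the function name and signature are made up, and it assumes the usual includes of WinogradConv2D.cpp:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helper: materialize a transform matrix directly in the
// convolution's float element type instead of always emitting f32.
static Value createTransformMatrixInElementType(OpBuilder &builder,
                                                Location loc,
                                                llvm::ArrayRef<float> table,
                                                int64_t rows, int64_t cols,
                                                FloatType elementType) {
  assert(table.size() == static_cast<size_t>(rows * cols));
  llvm::SmallVector<llvm::APFloat> values;
  values.reserve(table.size());
  for (float v : table) {
    llvm::APFloat apVal(v);
    bool losesInfo = false;
    // Round each f32 constant to the target semantics (f16, bf16, f32, ...).
    apVal.convert(elementType.getFloatSemantics(),
                  llvm::APFloat::rmNearestTiesToEven, &losesInfo);
    values.push_back(apVal);
  }
  auto tensorType = RankedTensorType::get({rows, cols}, elementType);
  return arith::ConstantOp::create(
      builder, loc, DenseFPElementsAttr::get(tensorType, values));
}
```

With something along these lines, the constant matrices would carry the same element type as the filter/input operands, so the attribute creation never sees a mismatched element type.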
Please let me know if there is anything else you'd like to see changed/improved @Hsiangkai
LGTM
The Winograd transform constant array is always emitted as f32, but previously the attribute creation would pass through the original element type. If that type was smaller (such as f16), attribute creation failed with an assertion.
This patch fixes the issue by ensuring that the constant's element type matches the f32 data, and adds a test for this case.