Commit cc14b58
[MLIR][Linalg] Fix winograd op lowering for types smaller than f32 (#158500)
The Winograd transform constant array is always emitted as f32, but previously the constant creation would pass the raw f32 data through with the original element type. If that type was smaller than f32 (like f16), attribute creation would fail an assertion. This fixes the issue by ensuring that the element types match, and adds a test for this case.
1 parent 1395d43 commit cc14b58
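For context on the failure mode and the fix, here is a minimal sketch written against the upstream MLIR builder API; buildTransformAttr and its parameter names are illustrative stand-ins, not code from the patch:

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVectorExtras.h"

using namespace mlir;

// Build the dense constant for one Winograd transform matrix. `type` is the
// element type of the surrounding convolution (f16 in the new test below).
DenseFPElementsAttr buildTransformAttr(OpBuilder &builder,
                                       ArrayRef<float> table, int64_t rows,
                                       int64_t cols, Type type) {
  // Before the fix (simplified), the raw f32 data was handed directly to the
  // attribute with the original element type; for f16 this fails an assertion
  // inside DenseElementsAttr construction, because the raw data width no
  // longer matches the element type:
  //
  //   DenseFPElementsAttr::get(RankedTensorType::get({rows, cols}, type),
  //                            table);
  //
  // After the fix, every value becomes a FloatAttr of the target type first,
  // so the attribute storage always matches the element type.
  auto attrs = llvm::map_to_vector(table, [&](float v) -> Attribute {
    return builder.getFloatAttr(type, v);
  });
  return DenseFPElementsAttr::get(RankedTensorType::get({rows, cols}, type),
                                  attrs);
}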

File tree

2 files changed: +132 -12 lines changed

  mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
  mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 16 additions & 12 deletions
@@ -186,11 +186,11 @@ constexpr float A_2x2_5x5[] = {
 
 /// Structure to keep information of constant transform matrices.
 struct TransformMatrix {
-  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+  TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
                   int64_t scalarFactor = 1)
       : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
 
-  const float *table;
+  ArrayRef<float> table;
   int64_t rows;
   int64_t cols;
   int64_t scalarFactor;
@@ -199,14 +199,20 @@ struct TransformMatrix {
 /// Utility function to convert constant array to arith.constant Value.
 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                               TransformMatrix transform, Type type) {
-  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
-
+  assert(transform.table.size() ==
+         static_cast<size_t>(transform.rows * transform.cols));
+  assert(type.isFloat() && "Only floats are supported by Winograd");
+  ArrayRef<float> constVec(transform.table.data(),
+                           transform.rows * transform.cols);
+  auto constAttrVec =
+      llvm::map_to_vector<>(constVec, [&](const float v) -> Attribute {
+        return builder.getFloatAttr(type, v);
+      });
+  SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
   return arith::ConstantOp::create(
       builder, loc,
-      DenseFPElementsAttr::get(
-          RankedTensorType::get(
-              SmallVector<int64_t>{transform.rows, transform.cols}, type),
-          constVec));
+      DenseFPElementsAttr::get(RankedTensorType::get(shape, type),
+                               constAttrVec));
 }
 
 /// Extract height x width data from 4D tensors.
@@ -551,8 +557,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
-      Value BT =
-          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+      Value BT = create2DTransformMatrix(builder, loc, BTMatrix, elementType);
       // Multiply BT x d.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{BT, matmulRetValue},
@@ -574,8 +579,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
              .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
-     Value B =
-         create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+     Value B = create2DTransformMatrix(builder, loc, BMatrix, elementType);
      // Multiply v = (BT x d) x B.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{matmulRetValue, B},
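A side effect worth calling out: builder.getFloatAttr(type, v) converts the f32 table value into the semantics of the requested type, so the f16 matrices in the new test carry rounded entries (for example -3.332520e-01, which is consistent with -1/3 rounded to f16). Below is a minimal sketch of that rounding with llvm::APFloat; roundTableEntryToF16 is a hypothetical helper for illustration, not part of the patch:

#include "llvm/ADT/APFloat.h"

// Round an f32 transform-table entry into f16 semantics, then widen it back
// to f32 for inspection. This mirrors (rather than quotes) what the builder
// does when it makes an f16 FloatAttr from an f32 value.
float roundTableEntryToF16(float entry) {
  llvm::APFloat v(entry);
  bool losesInfo = false;
  (void)v.convert(llvm::APFloat::IEEEhalf(),
                  llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  (void)v.convert(llvm::APFloat::IEEEsingle(),
                  llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  return v.convertToFloat();  // e.g. roundTableEntryToF16(-1.0f / 3.0f)
                              // yields -3.332520e-01, as in the test below.
}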

mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir

Lines changed: 116 additions & 0 deletions
@@ -127,3 +127,119 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
 // CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
 // CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<6x6x5x2xf16>
+  %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf16>) outs(%0 : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16> // no-crash
+  %2 = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+  %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x6x5xf16>) outs(%2 : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16> // no-crash
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+  %4 = tensor.empty() : tensor<36x2x2xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%5 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+  %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %7 : tensor<2x4x4x2xf32>
+}
+
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_type_promotion(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x6x6x5xf16>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<2x3x3x5xf16>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<1xf32>,
+// CHECK-SAME: %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf16>
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf16>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, -3.332520e-01, -3.332520e-01, 8.331300e-02, 8.331300e-02, 0.000000e+00], [0.000000e+00, 3.332520e-01, -3.332520e-01, -1.666260e-01, 1.666260e-01, 0.000000e+00], [0.000000e+00, -3.332520e-01, -3.332520e-01, 3.332520e-01, 3.332520e-01, 1.000000e+00]]> : tensor<3x6xf16>
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, 0.000000e+00, 0.000000e+00], [-3.332520e-01, 3.332520e-01, -3.332520e-01], [-3.332520e-01, -3.332520e-01, -3.332520e-01], [8.331300e-02, -1.666260e-01, 3.332520e-01], [8.331300e-02, 1.666260e-01, 3.332520e-01], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf16>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT: %[[VAL_20:.*]] = tensor.extract_slice %[[ARG1]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_11]], %[[VAL_18]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf16> to tensor<3x3xf16>
+// CHECK-NEXT: %[[VAL_21:.*]] = tensor.empty() : tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_22:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_21]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_23:.*]] = linalg.matmul ins(%[[VAL_6]], %[[VAL_20]] : tensor<6x3xf16>, tensor<3x3xf16>) outs(%[[VAL_22]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_24:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_25:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_24]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_26:.*]] = linalg.matmul ins(%[[VAL_23]], %[[VAL_5]] : tensor<6x3xf16>, tensor<3x6xf16>) outs(%[[VAL_25]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_27:.*]] = tensor.insert_slice %[[VAL_26]] into %[[VAL_19]]{{\[}}%[[VAL_11]], %[[VAL_11]], %[[VAL_18]], %[[VAL_15]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x5x2xf16>
+// CHECK-NEXT: scf.yield %[[VAL_27]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_17]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[VAL_28:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_32:.*]] = scf.for %[[VAL_33:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_34:.*]] = %[[VAL_31]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_35:.*]] = scf.for %[[VAL_36:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_37:.*]] = %[[VAL_34]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_38:.*]] = scf.for %[[VAL_39:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_40:.*]] = %[[VAL_37]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_41:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_30]])
+// CHECK-NEXT: %[[VAL_42:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_33]])
+// CHECK-NEXT: %[[VAL_43:.*]] = tensor.extract_slice %[[ARG0]]{{\[}}%[[VAL_36]], %[[VAL_41]], %[[VAL_42]], %[[VAL_39]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf16> to tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_44:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_45:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_44]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_46:.*]] = linalg.matmul ins(%[[VAL_4]], %[[VAL_43]] : tensor<6x6xf16>, tensor<6x6xf16>) outs(%[[VAL_45]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_47:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_48:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_47]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_49:.*]] = linalg.matmul ins(%[[VAL_46]], %[[VAL_3]] : tensor<6x6xf16>, tensor<6x6xf16>) outs(%[[VAL_48]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_50:.*]] = tensor.insert_slice %[[VAL_49]] into %[[VAL_40]][0, 0, %[[VAL_30]], %[[VAL_33]], %[[VAL_36]], %[[VAL_39]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: scf.yield %[[VAL_50]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_38]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_35]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_32]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[VAL_51:.*]] = tensor.collapse_shape %[[VAL_14]] {{\[\[}}0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT: %[[VAL_52:.*]] = tensor.collapse_shape %[[VAL_29]] {{\[\[}}0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT: %[[VAL_53:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_54:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_53]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_55:.*]] = linalg.batch_matmul ins(%[[VAL_52]], %[[VAL_51]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[VAL_54]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_56:.*]] = tensor.expand_shape %[[VAL_55]] {{\[\[}}0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_59:.*]] = %[[ARG3]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_60:.*]] = scf.for %[[VAL_61:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_62:.*]] = %[[VAL_59]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_65:.*]] = %[[VAL_62]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_66:.*]] = scf.for %[[VAL_67:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_68:.*]] = %[[VAL_65]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_69:.*]] = tensor.extract_slice %[[VAL_56]][0, 0, %[[VAL_58]], %[[VAL_61]], %[[VAL_64]], %[[VAL_67]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6xf32>
+// CHECK-NEXT: %[[VAL_70:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_58]])
+// CHECK-NEXT: %[[VAL_71:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_61]])
+// CHECK-NEXT: %[[VAL_72:.*]] = tensor.extract_slice %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x4x4x2xf32> to tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_73:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_74:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_73]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_75:.*]] = linalg.matmul ins(%[[VAL_2]], %[[VAL_69]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[VAL_74]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_76:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_77:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_76]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_78:.*]] = linalg.matmul ins(%[[VAL_75]], %[[VAL_1]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[VAL_77]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_79:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_78]] : f32, tensor<4x4xf32>) outs(%[[VAL_72]] : tensor<4x4xf32>) {
+// CHECK-NEXT: ^bb0(%[[VAL_80:.*]]: f32, %[[VAL_81:.*]]: f32, %[[VAL_82:.*]]: f32):
+// CHECK-NEXT: %[[VAL_83:.*]] = arith.mulf %[[VAL_80]], %[[VAL_81]] : f32
+// CHECK-NEXT: %[[VAL_84:.*]] = arith.addf %[[VAL_83]], %[[VAL_82]] : f32
+// CHECK-NEXT: linalg.yield %[[VAL_84]] : f32
+// CHECK-NEXT: } -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_85:.*]] = tensor.insert_slice %[[VAL_79]] into %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT: scf.yield %[[VAL_85]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_66]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_63]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_60]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[VAL_57]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
