
Conversation

Contributor

@nuudlman commented Sep 14, 2025

The Winograd transform constant arrays are always defined as f32, but the constant creation previously passed through the original element type. If that type was smaller (e.g. f16), attribute creation hit an assertion failure.

This patch fixes the issue by ensuring that the types match, and adds a test for this case.
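
For context, a minimal standalone sketch of the failure mode and of the fix; the helper name and the shortened table are illustrative only, not code from the patch, and the snippet assumes the headers and namespaces already used in WinogradConv2D.cpp (e.g. `using namespace mlir;`):

// Illustrative sketch only. The real Winograd constant tables are C arrays of
// f32 literals; this shortened 2x2 table is a stand-in.
constexpr float exampleTable[] = {0.25f, 0.0f, 0.0f, -0.25f};

// Before the patch, the element type of the filter/input (possibly f16) was
// forwarded into the tensor type below. Building the dense attribute from an
// ArrayRef<float> requires a 32-bit float element type, so an f16 tensor type
// triggered an assertion during attribute creation. Pinning the element type
// to f32, as the patch does, keeps the types consistent.
static Value createExampleTransformConstant(OpBuilder &builder, Location loc) {
  auto type = RankedTensorType::get({2, 2}, builder.getF32Type());
  return arith::ConstantOp::create(
      builder, loc,
      DenseFPElementsAttr::get(type, ArrayRef<float>(exampleTable)));
}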

Member

@llvmbot commented Sep 14, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Isaac Nudelman (nuudlman)

Changes

The Winograd transform constant arrays are always defined as f32, but the constant creation previously passed through the original element type. If that type was smaller (e.g. f16), attribute creation hit an assertion failure.

This patch fixes the issue by always promoting the type of the Winograd constants to f32, and adds a test for this case.


Full diff: https://github.com/llvm/llvm-project/pull/158500.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+15-15)
  • (modified) mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir (+116)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index b80b27fe5fcc5..b875b24c8fda0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -186,11 +186,12 @@ constexpr float A_2x2_5x5[] = {
 
 /// Structure to keep information of constant transform matrices.
 struct TransformMatrix {
-  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+  TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
                   int64_t scalarFactor = 1)
-      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {
+  }
 
-  const float *table;
+  ArrayRef<float> table;
   int64_t rows;
   int64_t cols;
   int64_t scalarFactor;
@@ -198,15 +199,14 @@ struct TransformMatrix {
 
 /// Utility function to convert constant array to arith.constant Value.
 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
-                              TransformMatrix transform, Type type) {
-  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
-
+                              TransformMatrix transform) {
+  assert(transform.table.size() == static_cast<size_t>(transform.rows * transform.cols));
+  ArrayRef<float> constVec(transform.table.data(), transform.rows * transform.cols);
+  SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
   return arith::ConstantOp::create(
       builder, loc,
       DenseFPElementsAttr::get(
-          RankedTensorType::get(
-              SmallVector<int64_t>{transform.rows, transform.cols}, type),
-          constVec));
+          RankedTensorType::get(shape, builder.getF32Type()), constVec));
 }
 
 /// Extract height x width data from 4D tensors.
@@ -404,7 +404,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
-      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
+      Value G = create2DTransformMatrix(builder, loc, GMatrix);
       // Multiply G x g.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{G, extractFilter},
@@ -427,7 +427,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
-      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
+      Value GT = create2DTransformMatrix(builder, loc, GTMatrix);
       // Multiply u = (G x g) x GT.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, GT},
@@ -552,7 +552,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
       Value BT =
-          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+          create2DTransformMatrix(builder, loc, BTMatrix);
       // Multiply BT x d.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{BT, matmulRetValue},
@@ -575,7 +575,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       Value B =
-          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+          create2DTransformMatrix(builder, loc, BMatrix);
       // Multiply v = (BT x d) x B.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, B},
@@ -783,7 +783,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
         init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       }
 
-      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
+      Value AT = create2DTransformMatrix(builder, loc, ATMatrix);
       // Multiply AT x m.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{AT, matmulRetValue},
@@ -802,7 +802,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
         init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       }
 
-      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
+      Value A = create2DTransformMatrix(builder, loc, AMatrix);
       // Multiply y = (AT x m) x A.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, A},
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index c7b0bd51308ba..4bcb9b0c2c465 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -127,3 +127,119 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
 // CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
 // CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<6x6x5x2xf16>
+  %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf16>) outs(%0 : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16> // no-crash
+  %2 = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+  %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x6x5xf16>) outs(%2 : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16> // no-crash
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+  %4 = tensor.empty() : tensor<36x2x2xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%5 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+  %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %7 : tensor<2x4x4x2xf32>
+}
+
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL:   func.func @conv2d_type_promotion(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x6x6x5xf16>,
+// CHECK-SAME:      %[[ARG1:.*]]: tensor<2x3x3x5xf16>,
+// CHECK-SAME:      %[[ARG2:.*]]: tensor<1xf32>,
+// CHECK-SAME:      %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG:           %[[VAL_0:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:           %[[VAL_1:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:           %[[VAL_2:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:           %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:           %[[VAL_4:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:           %[[VAL_5:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:           %[[VAL_6:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK-DAG:           %[[VAL_9:.*]] = arith.constant 5 : index
+// CHECK-DAG:           %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK-DAG:           %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK-DAG:           %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_13:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT:             %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT:               %[[VAL_20:.*]] = tensor.extract_slice %[[ARG1]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_11]], %[[VAL_18]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf16> to tensor<3x3xf16>
+// CHECK-NEXT:               %[[VAL_21:.*]] = tensor.empty() : tensor<6x3xf16>
+// CHECK-NEXT:               %[[VAL_22:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_21]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT:               %[[VAL_23:.*]] = linalg.matmul ins(%[[VAL_6]], %[[VAL_20]] : tensor<6x3xf32>, tensor<3x3xf16>) outs(%[[VAL_22]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT:               %[[VAL_24:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT:               %[[VAL_25:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_24]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:               %[[VAL_26:.*]] = linalg.matmul ins(%[[VAL_23]], %[[VAL_5]] : tensor<6x3xf16>, tensor<3x6xf32>) outs(%[[VAL_25]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:               %[[VAL_27:.*]] = tensor.insert_slice %[[VAL_26]] into %[[VAL_19]]{{\[}}%[[VAL_11]], %[[VAL_11]], %[[VAL_18]], %[[VAL_15]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x5x2xf16>
+// CHECK-NEXT:               scf.yield %[[VAL_27]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[VAL_17]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           %[[VAL_28:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:           %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:             %[[VAL_32:.*]] = scf.for %[[VAL_33:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_34:.*]] = %[[VAL_31]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:               %[[VAL_35:.*]] = scf.for %[[VAL_36:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_37:.*]] = %[[VAL_34]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:                 %[[VAL_38:.*]] = scf.for %[[VAL_39:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_40:.*]] = %[[VAL_37]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:                   %[[VAL_41:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_30]])
+// CHECK-NEXT:                   %[[VAL_42:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_33]])
+// CHECK-NEXT:                   %[[VAL_43:.*]] = tensor.extract_slice %[[ARG0]]{{\[}}%[[VAL_36]], %[[VAL_41]], %[[VAL_42]], %[[VAL_39]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf16> to tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_44:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_45:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_44]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_46:.*]] = linalg.matmul ins(%[[VAL_4]], %[[VAL_43]] : tensor<6x6xf32>, tensor<6x6xf16>) outs(%[[VAL_45]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_47:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_48:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_47]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_49:.*]] = linalg.matmul ins(%[[VAL_46]], %[[VAL_3]] : tensor<6x6xf16>, tensor<6x6xf32>) outs(%[[VAL_48]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_50:.*]] = tensor.insert_slice %[[VAL_49]] into %[[VAL_40]][0, 0, %[[VAL_30]], %[[VAL_33]], %[[VAL_36]], %[[VAL_39]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:                   scf.yield %[[VAL_50]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:                 }
+// CHECK-NEXT:                 scf.yield %[[VAL_38]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:               }
+// CHECK-NEXT:               scf.yield %[[VAL_35]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[VAL_32]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           %[[VAL_51:.*]] = tensor.collapse_shape %[[VAL_14]] {{\[\[}}0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT:           %[[VAL_52:.*]] = tensor.collapse_shape %[[VAL_29]] {{\[\[}}0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT:           %[[VAL_53:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:           %[[VAL_54:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_53]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:           %[[VAL_55:.*]] = linalg.batch_matmul ins(%[[VAL_52]], %[[VAL_51]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[VAL_54]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:           %[[VAL_56:.*]] = tensor.expand_shape %[[VAL_55]] {{\[\[}}0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:           %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_59:.*]] = %[[ARG3]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:             %[[VAL_60:.*]] = scf.for %[[VAL_61:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_62:.*]] = %[[VAL_59]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:               %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_65:.*]] = %[[VAL_62]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:                 %[[VAL_66:.*]] = scf.for %[[VAL_67:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_68:.*]] = %[[VAL_65]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:                   %[[VAL_69:.*]] = tensor.extract_slice %[[VAL_56]][0, 0, %[[VAL_58]], %[[VAL_61]], %[[VAL_64]], %[[VAL_67]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6xf32>
+// CHECK-NEXT:                   %[[VAL_70:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_58]])
+// CHECK-NEXT:                   %[[VAL_71:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_61]])
+// CHECK-NEXT:                   %[[VAL_72:.*]] = tensor.extract_slice %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x4x4x2xf32> to tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_73:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:                   %[[VAL_74:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_73]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:                   %[[VAL_75:.*]] = linalg.matmul ins(%[[VAL_2]], %[[VAL_69]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[VAL_74]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:                   %[[VAL_76:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_77:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_76]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_78:.*]] = linalg.matmul ins(%[[VAL_75]], %[[VAL_1]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[VAL_77]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_79:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_78]] : f32, tensor<4x4xf32>) outs(%[[VAL_72]] : tensor<4x4xf32>) {
+// CHECK-NEXT:                   ^bb0(%[[VAL_80:.*]]: f32, %[[VAL_81:.*]]: f32, %[[VAL_82:.*]]: f32):
+// CHECK-NEXT:                     %[[VAL_83:.*]] = arith.mulf %[[VAL_80]], %[[VAL_81]] : f32
+// CHECK-NEXT:                     %[[VAL_84:.*]] = arith.addf %[[VAL_83]], %[[VAL_82]] : f32
+// CHECK-NEXT:                     linalg.yield %[[VAL_84]] : f32
+// CHECK-NEXT:                   } -> tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_85:.*]] = tensor.insert_slice %[[VAL_79]] into %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:                   scf.yield %[[VAL_85]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:                 }
+// CHECK-NEXT:                 scf.yield %[[VAL_66]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:               }
+// CHECK-NEXT:               scf.yield %[[VAL_63]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[VAL_60]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           return %[[VAL_57]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:         }


github-actions bot commented Sep 14, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@nuudlman
Contributor Author

@hanhanW @nirvedhmeshram Would one of you be able to review this?

@makslevental requested a review from Hsiangkai October 1, 2025 18:18
@makslevental
Contributor

ping @Hsiangkai

Hsiangkai previously approved these changes Oct 2, 2025

Contributor

@Hsiangkai left a comment

LGTM. Comments added below.

@Hsiangkai dismissed their stale review October 2, 2025 06:16

Why not use the element type of the input/filter values for all of these constant matrices?

@Hsiangkai
Contributor

Something like:

--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -201,12 +201,16 @@ Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                               TransformMatrix transform, Type type) {
   ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);

+  SmallVector<Attribute> constAttrVec;
+  for (float v : constVec)
+    constAttrVec.push_back(builder.getFloatAttr(type, v));
+
   return arith::ConstantOp::create(
       builder, loc,
       DenseFPElementsAttr::get(
           RankedTensorType::get(
               SmallVector<int64_t>{transform.rows, transform.cols}, type),
-          constVec));
+          constAttrVec));
 }

 /// Extract height x width data from 4D tensors.
@@ -552,7 +556,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);

       Value BT =
-          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+          create2DTransformMatrix(builder, loc, BTMatrix, elementType);
       // Multiply BT x d.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{BT, matmulRetValue},
@@ -575,7 +579,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       Value B =
-          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+          create2DTransformMatrix(builder, loc, BMatrix, elementType);
       // Multiply v = (BT x d) x B.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, B},
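
For reference, a minimal standalone sketch of the idea behind this suggestion (the helper name create2DTransformMatrixAs is hypothetical, and the snippet assumes the headers and namespaces already used in WinogradConv2D.cpp): builder.getFloatAttr converts each f32 constant into the requested element type, so the dense attribute can be created directly in, say, f16 without a type mismatch.

// Hypothetical helper illustrating the suggestion above; not part of the patch.
static Value create2DTransformMatrixAs(OpBuilder &builder, Location loc,
                                       ArrayRef<float> table, int64_t rows,
                                       int64_t cols, Type elementType) {
  // getFloatAttr converts each f32 constant into `elementType` (e.g. f16).
  SmallVector<Attribute> constAttrVec;
  constAttrVec.reserve(table.size());
  for (float v : table)
    constAttrVec.push_back(builder.getFloatAttr(elementType, v));
  return arith::ConstantOp::create(
      builder, loc,
      DenseFPElementsAttr::get(RankedTensorType::get({rows, cols}, elementType),
                               constAttrVec));
}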

Contributor Author

@nuudlman commented Oct 7, 2025

Thanks for this. I had tried to figure out how to emit f16 attributes and failed; I'll update the patch to do this instead.

@nuudlman requested a review from Hsiangkai October 7, 2025 20:46
Contributor Author

@nuudlman commented Oct 8, 2025

Please let me know if there is anything else you'd like to see changed or improved, @Hsiangkai.

Contributor

@Hsiangkai left a comment

LGTM

Contributor Author

@nuudlman commented Oct 8, 2025

Can you please merge it for me? I don't have commit access.

@vabridgers merged commit cc14b58 into llvm:main Oct 8, 2025
9 checks passed