Skip to content

Commit d4d2121

Browse files
committed
add a concept of a fix
1 parent ba3b3e3 commit d4d2121

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 25 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -46,46 +46,46 @@ namespace {
4646
/// BTMatrices, BMatrices, ATMatrices, or AMatrices map.
4747
/// 3. Add an enum value F_m_r to the WinogradConv2DFmr enum.
4848
///
49-
constexpr float G_2x2_3x3[] = {
49+
constexpr double G_2x2_3x3[] = {
5050
-1, 0, 0,
5151
1./2, -1./2, 1./2,
5252
1./2, 1./2, 1./2,
5353
0, 0, 1
5454
};
5555

56-
constexpr float GT_2x2_3x3[] = {
56+
constexpr double GT_2x2_3x3[] = {
5757
-1, 1./2, 1./2, 0,
5858
0, -1./2, 1./2, 0,
5959
0, 1./2, 1./2, 1
6060
};
6161

62-
constexpr float BT_2x2_3x3[] = {
62+
constexpr double BT_2x2_3x3[] = {
6363
-1, 0, 1, 0,
6464
0, -1, 1, 0,
6565
0, 1, 1, 0,
6666
0, -1, 0, 1
6767
};
6868

69-
constexpr float B_2x2_3x3[] = {
69+
constexpr double B_2x2_3x3[] = {
7070
-1, 0, 0, 0,
7171
0, -1, 1, -1,
7272
1, 1, 1, 0,
7373
0, 0, 0, 1
7474
};
7575

76-
constexpr float AT_2x2_3x3[] = {
76+
constexpr double AT_2x2_3x3[] = {
7777
1, 1, 1, 0,
7878
0, -1, 1, 1
7979
};
8080

81-
constexpr float A_2x2_3x3[] = {
81+
constexpr double A_2x2_3x3[] = {
8282
1, 0,
8383
1, -1,
8484
1, 1,
8585
0, 1
8686
};
8787

88-
constexpr float G_4x4_3x3[] = {
88+
constexpr double G_4x4_3x3[] = {
8989
1, 0, 0,
9090
-1./3, 1./3, -1./3,
9191
-1./3, -1./3, -1./3,
@@ -94,13 +94,13 @@ constexpr float G_4x4_3x3[] = {
9494
0, 0, 1
9595
};
9696

97-
constexpr float GT_4x4_3x3[] = {
97+
constexpr double GT_4x4_3x3[] = {
9898
1, -1./3, -1./3, 1./12, 1./12, 0,
9999
0, 1./3, -1./3, -1./6, 1./6, 0,
100100
0, -1./3, -1./3, 1./3, 1./3, 1
101101
};
102102

103-
constexpr float BT_4x4_3x3[] = {
103+
constexpr double BT_4x4_3x3[] = {
104104
1./4, 0, -5./16, 0, 1./16, 0,
105105
0, 1./4, -1./4, -1./16, 1./16, 0,
106106
0, -1./4, -1./4, 1./16, 1./16, 0,
@@ -109,7 +109,7 @@ constexpr float BT_4x4_3x3[] = {
109109
0, 1./4, 0, -5./16, 0, 1./16
110110
};
111111

112-
constexpr float B_4x4_3x3[] = {
112+
constexpr double B_4x4_3x3[] = {
113113
1./4, 0, 0, 0, 0, 0,
114114
0, 1./4, -1./4, 1./4, -1./4, 1./4,
115115
-5./16, -1./4, -1./4, -1./8, -1./8, 0,
@@ -118,14 +118,14 @@ constexpr float B_4x4_3x3[] = {
118118
0, 0, 0, 0, 0, 1./16
119119
};
120120

121-
constexpr float AT_4x4_3x3[] = {
121+
constexpr double AT_4x4_3x3[] = {
122122
1./8, 1./4, 1./4, 1./8, 1./8, 0,
123123
0, -1./4, 1./4, -1./4, 1./4, 0,
124124
0, 1./4, 1./4, 1./2, 1./2, 0,
125125
0, -1./4, 1./4, -1, 1, 1./2
126126
};
127127

128-
constexpr float A_4x4_3x3[] = {
128+
constexpr double A_4x4_3x3[] = {
129129
1./8, 0, 0, 0,
130130
1./4, -1./4, 1./4, -1./4,
131131
1./4, 1./4, 1./4, 1./4,
@@ -134,7 +134,7 @@ constexpr float A_4x4_3x3[] = {
134134
0, 0, 0, 1./2
135135
};
136136

137-
constexpr float G_2x2_5x5[] = {
137+
constexpr double G_2x2_5x5[] = {
138138
1, 0, 0, 0, 0,
139139
1./6, -1./6, 1./6, -1./6, 1./6,
140140
-1./6, -1./6, -1./6, -1./6, -1./6,
@@ -143,15 +143,15 @@ constexpr float G_2x2_5x5[] = {
143143
0, 0, 0, 0, 1
144144
};
145145

146-
constexpr float GT_2x2_5x5[] = {
146+
constexpr double GT_2x2_5x5[] = {
147147
1, 1./6, -1./6, -4./15, 1./60, 0,
148148
0, -1./6, -1./6, 2./15, 1./30, 0,
149149
0, 1./6, -1./6, -1./15, 1./15, 0,
150150
0, -1./6, -1./6, 1./30, 2./15, 0,
151151
0, 1./6, -1./6, -1./60, 4./15, 1
152152
};
153153

154-
constexpr float BT_2x2_5x5[] = {
154+
constexpr double BT_2x2_5x5[] = {
155155
1./8, 3./16, -1./4, -3./16, 1./8, 0,
156156
0, 1./8, 1./16, -5./16, 1./8, 0,
157157
0, -1./8, -5./16, -1./16, 1./8, 0,
@@ -160,7 +160,7 @@ constexpr float BT_2x2_5x5[] = {
160160
0, 1./8, 3./16, -1./4, -3./16, 1./8
161161
};
162162

163-
constexpr float B_2x2_5x5[] = {
163+
constexpr double B_2x2_5x5[] = {
164164
1./8, 0, 0, 0, 0, 0,
165165
3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
166166
-1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
@@ -169,12 +169,12 @@ constexpr float B_2x2_5x5[] = {
169169
0, 0, 0, 0, 0, 1./8
170170
};
171171

172-
constexpr float AT_2x2_5x5[] = {
172+
constexpr double AT_2x2_5x5[] = {
173173
1./2, 1, 1, 2, 1, 0,
174174
0, -1, 1, -1, 2, 1./2
175175
};
176176

177-
constexpr float A_2x2_5x5[] = {
177+
constexpr double A_2x2_5x5[] = {
178178
1./2, 0,
179179
1, -1,
180180
1, 1,
@@ -186,11 +186,12 @@ constexpr float A_2x2_5x5[] = {
186186

187187
/// Structure that keeps the information about a constant transform matrix.
188188
struct TransformMatrix {
189-
TransformMatrix(const float *table, int64_t rows, int64_t cols,
189+
TransformMatrix(ArrayRef<double> table, int64_t rows, int64_t cols,
190190
int64_t scalarFactor = 1)
191-
: table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
191+
: table(llvm::map_to_vector(table, [](double val) { return APFloat(val); })), rows(rows), cols(cols), scalarFactor(scalarFactor) {
192+
}
192193

193-
const float *table;
194+
SmallVector<APFloat> table;
194195
int64_t rows;
195196
int64_t cols;
196197
int64_t scalarFactor;
@@ -199,7 +200,9 @@ struct TransformMatrix {
199200
/// Utility function to convert a constant array to an arith.constant Value.
200201
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
201202
TransformMatrix transform, Type type) {
202-
ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
203+
assert(type.isFloat());
204+
assert(transform.table.size() == (transform.rows * transform.cols));
205+
ArrayRef<APFloat> constVec(transform.table.data(), transform.rows * transform.cols);
203206

204207
return arith::ConstantOp::create(
205208
builder, loc,

mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -127,3 +127,21 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
127127
// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
128128
// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
129129
// CHECK-NEXT: }
130+
131+
// -----
132+
133+
func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
134+
%cst = arith.constant 0.000000e+00 : f32
135+
%0 = tensor.empty() : tensor<6x6x5x2xf16>
136+
%1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf16>) outs(%0 : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16> // no-crash
137+
%2 = tensor.empty() : tensor<6x6x1x1x2x5xf16>
138+
%3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x6x5xf16>) outs(%2 : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16> // no-crash
139+
%collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
140+
%collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
141+
%4 = tensor.empty() : tensor<36x2x2xf32>
142+
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
143+
%6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%5 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
144+
%expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
145+
%7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
146+
return %7 : tensor<2x4x4x2xf32>
147+
}

0 commit comments

Comments
 (0)