Skip to content

Commit 4ca18ad

Browse files
committed
Always promote winograd lowering to f32
1 parent d4d2121 commit 4ca18ad

File tree

1 file changed

+34
-35
lines changed

1 file changed

+34
-35
lines changed

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -46,46 +46,46 @@ namespace {
4646
/// BTMatrices, BMatrices, ATMatrices, or AMatrices map.
4747
/// 3. Add a enum value F_m_r to WinogradConv2DFmr enum.
4848
///
49-
constexpr double G_2x2_3x3[] = {
49+
constexpr float G_2x2_3x3[] = {
5050
-1, 0, 0,
5151
1./2, -1./2, 1./2,
5252
1./2, 1./2, 1./2,
5353
0, 0, 1
5454
};
5555

56-
constexpr double GT_2x2_3x3[] = {
56+
constexpr float GT_2x2_3x3[] = {
5757
-1, 1./2, 1./2, 0,
5858
0, -1./2, 1./2, 0,
5959
0, 1./2, 1./2, 1
6060
};
6161

62-
constexpr double BT_2x2_3x3[] = {
62+
constexpr float BT_2x2_3x3[] = {
6363
-1, 0, 1, 0,
6464
0, -1, 1, 0,
6565
0, 1, 1, 0,
6666
0, -1, 0, 1
6767
};
6868

69-
constexpr double B_2x2_3x3[] = {
69+
constexpr float B_2x2_3x3[] = {
7070
-1, 0, 0, 0,
7171
0, -1, 1, -1,
7272
1, 1, 1, 0,
7373
0, 0, 0, 1
7474
};
7575

76-
constexpr double AT_2x2_3x3[] = {
76+
constexpr float AT_2x2_3x3[] = {
7777
1, 1, 1, 0,
7878
0, -1, 1, 1
7979
};
8080

81-
constexpr double A_2x2_3x3[] = {
81+
constexpr float A_2x2_3x3[] = {
8282
1, 0,
8383
1, -1,
8484
1, 1,
8585
0, 1
8686
};
8787

88-
constexpr double G_4x4_3x3[] = {
88+
constexpr float G_4x4_3x3[] = {
8989
1, 0, 0,
9090
-1./3, 1./3, -1./3,
9191
-1./3, -1./3, -1./3,
@@ -94,13 +94,13 @@ constexpr double G_4x4_3x3[] = {
9494
0, 0, 1
9595
};
9696

97-
constexpr double GT_4x4_3x3[] = {
97+
constexpr float GT_4x4_3x3[] = {
9898
1, -1./3, -1./3, 1./12, 1./12, 0,
9999
0, 1./3, -1./3, -1./6, 1./6, 0,
100100
0, -1./3, -1./3, 1./3, 1./3, 1
101101
};
102102

103-
constexpr double BT_4x4_3x3[] = {
103+
constexpr float BT_4x4_3x3[] = {
104104
1./4, 0, -5./16, 0, 1./16, 0,
105105
0, 1./4, -1./4, -1./16, 1./16, 0,
106106
0, -1./4, -1./4, 1./16, 1./16, 0,
@@ -109,7 +109,7 @@ constexpr double BT_4x4_3x3[] = {
109109
0, 1./4, 0, -5./16, 0, 1./16
110110
};
111111

112-
constexpr double B_4x4_3x3[] = {
112+
constexpr float B_4x4_3x3[] = {
113113
1./4, 0, 0, 0, 0, 0,
114114
0, 1./4, -1./4, 1./4, -1./4, 1./4,
115115
-5./16, -1./4, -1./4, -1./8, -1./8, 0,
@@ -118,14 +118,14 @@ constexpr double B_4x4_3x3[] = {
118118
0, 0, 0, 0, 0, 1./16
119119
};
120120

121-
constexpr double AT_4x4_3x3[] = {
121+
constexpr float AT_4x4_3x3[] = {
122122
1./8, 1./4, 1./4, 1./8, 1./8, 0,
123123
0, -1./4, 1./4, -1./4, 1./4, 0,
124124
0, 1./4, 1./4, 1./2, 1./2, 0,
125125
0, -1./4, 1./4, -1, 1, 1./2
126126
};
127127

128-
constexpr double A_4x4_3x3[] = {
128+
constexpr float A_4x4_3x3[] = {
129129
1./8, 0, 0, 0,
130130
1./4, -1./4, 1./4, -1./4,
131131
1./4, 1./4, 1./4, 1./4,
@@ -134,7 +134,7 @@ constexpr double A_4x4_3x3[] = {
134134
0, 0, 0, 1./2
135135
};
136136

137-
constexpr double G_2x2_5x5[] = {
137+
constexpr float G_2x2_5x5[] = {
138138
1, 0, 0, 0, 0,
139139
1./6, -1./6, 1./6, -1./6, 1./6,
140140
-1./6, -1./6, -1./6, -1./6, -1./6,
@@ -143,15 +143,15 @@ constexpr double G_2x2_5x5[] = {
143143
0, 0, 0, 0, 1
144144
};
145145

146-
constexpr double GT_2x2_5x5[] = {
146+
constexpr float GT_2x2_5x5[] = {
147147
1, 1./6, -1./6, -4./15, 1./60, 0,
148148
0, -1./6, -1./6, 2./15, 1./30, 0,
149149
0, 1./6, -1./6, -1./15, 1./15, 0,
150150
0, -1./6, -1./6, 1./30, 2./15, 0,
151151
0, 1./6, -1./6, -1./60, 4./15, 1
152152
};
153153

154-
constexpr double BT_2x2_5x5[] = {
154+
constexpr float BT_2x2_5x5[] = {
155155
1./8, 3./16, -1./4, -3./16, 1./8, 0,
156156
0, 1./8, 1./16, -5./16, 1./8, 0,
157157
0, -1./8, -5./16, -1./16, 1./8, 0,
@@ -160,7 +160,7 @@ constexpr double BT_2x2_5x5[] = {
160160
0, 1./8, 3./16, -1./4, -3./16, 1./8
161161
};
162162

163-
constexpr double B_2x2_5x5[] = {
163+
constexpr float B_2x2_5x5[] = {
164164
1./8, 0, 0, 0, 0, 0,
165165
3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
166166
-1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
@@ -169,12 +169,12 @@ constexpr double B_2x2_5x5[] = {
169169
0, 0, 0, 0, 0, 1./8
170170
};
171171

172-
constexpr double AT_2x2_5x5[] = {
172+
constexpr float AT_2x2_5x5[] = {
173173
1./2, 1, 1, 2, 1, 0,
174174
0, -1, 1, -1, 2, 1./2
175175
};
176176

177-
constexpr double A_2x2_5x5[] = {
177+
constexpr float A_2x2_5x5[] = {
178178
1./2, 0,
179179
1, -1,
180180
1, 1,
@@ -186,30 +186,27 @@ constexpr double A_2x2_5x5[] = {
186186

187187
/// Structure to keep information of constant transform matrices.
188188
struct TransformMatrix {
189-
TransformMatrix(ArrayRef<double> table, int64_t rows, int64_t cols,
189+
TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
190190
int64_t scalarFactor = 1)
191-
: table(llvm::map_to_vector(table, [](double val) { return APFloat(val); })), rows(rows), cols(cols), scalarFactor(scalarFactor) {
191+
: table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {
192192
}
193193

194-
SmallVector<APFloat> table;
194+
ArrayRef<float> table;
195195
int64_t rows;
196196
int64_t cols;
197197
int64_t scalarFactor;
198198
};
199199

200200
/// Utility function to convert constant array to arith.constant Value.
201201
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
202-
TransformMatrix transform, Type type) {
203-
assert(type.isFloat());
204-
assert(transform.table.size() == (transform.rows * transform.cols));
205-
ArrayRef<APFloat> constVec(transform.table.data(), transform.rows * transform.cols);
206-
202+
TransformMatrix transform) {
203+
assert(transform.table.size() == static_cast<size_t>(transform.rows * transform.cols));
204+
ArrayRef<float> constVec(transform.table.data(), transform.rows * transform.cols);
205+
SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
207206
return arith::ConstantOp::create(
208207
builder, loc,
209208
DenseFPElementsAttr::get(
210-
RankedTensorType::get(
211-
SmallVector<int64_t>{transform.rows, transform.cols}, type),
212-
constVec));
209+
RankedTensorType::get(shape, builder.getF32Type()), constVec));
213210
}
214211

215212
/// Extract height x width data from 4D tensors.
@@ -407,7 +404,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
407404
auto init =
408405
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
409406

410-
Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
407+
Value G = create2DTransformMatrix(builder, loc, GMatrix);
411408
// Multiply G x g.
412409
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
413410
ValueRange{G, extractFilter},
@@ -430,7 +427,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
430427
auto init =
431428
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
432429

433-
Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
430+
Value GT = create2DTransformMatrix(builder, loc, GTMatrix);
434431
// Multiply u = (G x g) x GT.
435432
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
436433
ValueRange{matmulRetValue, GT},
@@ -500,6 +497,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
500497
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
501498
auto inputType = cast<ShapedType>(input.getType());
502499
Type elementType = inputType.getElementType();
500+
// assert(elementType.isF32() && "NYI: support non-f32");
503501
auto inputShape = inputType.getShape(); // N, H, W, C
504502
int64_t inputN = inputShape[0];
505503
int64_t inputC = inputShape[3];
@@ -555,7 +553,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
555553
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
556554

557555
Value BT =
558-
create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
556+
create2DTransformMatrix(builder, loc, BTMatrix);
559557
// Multiply BT x d.
560558
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
561559
ValueRange{BT, matmulRetValue},
@@ -578,7 +576,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
578576
auto init =
579577
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
580578
Value B =
581-
create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
579+
create2DTransformMatrix(builder, loc, BMatrix);
582580
// Multiply v = (BT x d) x B.
583581
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
584582
ValueRange{matmulRetValue, B},
@@ -723,6 +721,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
723721
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
724722
auto valueType = cast<ShapedType>(value.getType());
725723
Type elementType = valueType.getElementType();
724+
// assert(elementType.isF32() && "NYI: support non-f32");
726725
auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
727726
int64_t valueH = valueShape[0];
728727
int64_t valueW = valueShape[1];
@@ -786,7 +785,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
786785
init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
787786
}
788787

789-
Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
788+
Value AT = create2DTransformMatrix(builder, loc, ATMatrix);
790789
// Multiply AT x m.
791790
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
792791
ValueRange{AT, matmulRetValue},
@@ -805,7 +804,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
805804
init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
806805
}
807806

808-
Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
807+
Value A = create2DTransformMatrix(builder, loc, AMatrix);
809808
// Multiply y = (AT x m) x A.
810809
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
811810
ValueRange{matmulRetValue, A},

0 commit comments

Comments
 (0)