@@ -46,46 +46,46 @@ namespace {
46
46
// / BTMatrices, BMatrices, ATMatrices, or AMatrices map.
47
47
// / 3. Add a enum value F_m_r to WinogradConv2DFmr enum.
48
48
// /
49
- constexpr double G_2x2_3x3[] = {
49
+ constexpr float G_2x2_3x3[] = {
50
50
-1 , 0 , 0 ,
51
51
1 ./2 , -1 ./2 , 1 ./2 ,
52
52
1 ./2 , 1 ./2 , 1 ./2 ,
53
53
0 , 0 , 1
54
54
};
55
55
56
- constexpr double GT_2x2_3x3[] = {
56
+ constexpr float GT_2x2_3x3[] = {
57
57
-1 , 1 ./2 , 1 ./2 , 0 ,
58
58
0 , -1 ./2 , 1 ./2 , 0 ,
59
59
0 , 1 ./2 , 1 ./2 , 1
60
60
};
61
61
62
- constexpr double BT_2x2_3x3[] = {
62
+ constexpr float BT_2x2_3x3[] = {
63
63
-1 , 0 , 1 , 0 ,
64
64
0 , -1 , 1 , 0 ,
65
65
0 , 1 , 1 , 0 ,
66
66
0 , -1 , 0 , 1
67
67
};
68
68
69
- constexpr double B_2x2_3x3[] = {
69
+ constexpr float B_2x2_3x3[] = {
70
70
-1 , 0 , 0 , 0 ,
71
71
0 , -1 , 1 , -1 ,
72
72
1 , 1 , 1 , 0 ,
73
73
0 , 0 , 0 , 1
74
74
};
75
75
76
- constexpr double AT_2x2_3x3[] = {
76
+ constexpr float AT_2x2_3x3[] = {
77
77
1 , 1 , 1 , 0 ,
78
78
0 , -1 , 1 , 1
79
79
};
80
80
81
- constexpr double A_2x2_3x3[] = {
81
+ constexpr float A_2x2_3x3[] = {
82
82
1 , 0 ,
83
83
1 , -1 ,
84
84
1 , 1 ,
85
85
0 , 1
86
86
};
87
87
88
- constexpr double G_4x4_3x3[] = {
88
+ constexpr float G_4x4_3x3[] = {
89
89
1 , 0 , 0 ,
90
90
-1 ./3 , 1 ./3 , -1 ./3 ,
91
91
-1 ./3 , -1 ./3 , -1 ./3 ,
@@ -94,13 +94,13 @@ constexpr double G_4x4_3x3[] = {
94
94
0 , 0 , 1
95
95
};
96
96
97
- constexpr double GT_4x4_3x3[] = {
97
+ constexpr float GT_4x4_3x3[] = {
98
98
1 , -1 ./3 , -1 ./3 , 1 ./12 , 1 ./12 , 0 ,
99
99
0 , 1 ./3 , -1 ./3 , -1 ./6 , 1 ./6 , 0 ,
100
100
0 , -1 ./3 , -1 ./3 , 1 ./3 , 1 ./3 , 1
101
101
};
102
102
103
- constexpr double BT_4x4_3x3[] = {
103
+ constexpr float BT_4x4_3x3[] = {
104
104
1 ./4 , 0 , -5 ./16 , 0 , 1 ./16 , 0 ,
105
105
0 , 1 ./4 , -1 ./4 , -1 ./16 , 1 ./16 , 0 ,
106
106
0 , -1 ./4 , -1 ./4 , 1 ./16 , 1 ./16 , 0 ,
@@ -109,7 +109,7 @@ constexpr double BT_4x4_3x3[] = {
109
109
0 , 1 ./4 , 0 , -5 ./16 , 0 , 1 ./16
110
110
};
111
111
112
- constexpr double B_4x4_3x3[] = {
112
+ constexpr float B_4x4_3x3[] = {
113
113
1 ./4 , 0 , 0 , 0 , 0 , 0 ,
114
114
0 , 1 ./4 , -1 ./4 , 1 ./4 , -1 ./4 , 1 ./4 ,
115
115
-5 ./16 , -1 ./4 , -1 ./4 , -1 ./8 , -1 ./8 , 0 ,
@@ -118,14 +118,14 @@ constexpr double B_4x4_3x3[] = {
118
118
0 , 0 , 0 , 0 , 0 , 1 ./16
119
119
};
120
120
121
- constexpr double AT_4x4_3x3[] = {
121
+ constexpr float AT_4x4_3x3[] = {
122
122
1 ./8 , 1 ./4 , 1 ./4 , 1 ./8 , 1 ./8 , 0 ,
123
123
0 , -1 ./4 , 1 ./4 , -1 ./4 , 1 ./4 , 0 ,
124
124
0 , 1 ./4 , 1 ./4 , 1 ./2 , 1 ./2 , 0 ,
125
125
0 , -1 ./4 , 1 ./4 , -1 , 1 , 1 ./2
126
126
};
127
127
128
- constexpr double A_4x4_3x3[] = {
128
+ constexpr float A_4x4_3x3[] = {
129
129
1 ./8 , 0 , 0 , 0 ,
130
130
1 ./4 , -1 ./4 , 1 ./4 , -1 ./4 ,
131
131
1 ./4 , 1 ./4 , 1 ./4 , 1 ./4 ,
@@ -134,7 +134,7 @@ constexpr double A_4x4_3x3[] = {
134
134
0 , 0 , 0 , 1 ./2
135
135
};
136
136
137
- constexpr double G_2x2_5x5[] = {
137
+ constexpr float G_2x2_5x5[] = {
138
138
1 , 0 , 0 , 0 , 0 ,
139
139
1 ./6 , -1 ./6 , 1 ./6 , -1 ./6 , 1 ./6 ,
140
140
-1 ./6 , -1 ./6 , -1 ./6 , -1 ./6 , -1 ./6 ,
@@ -143,15 +143,15 @@ constexpr double G_2x2_5x5[] = {
143
143
0 , 0 , 0 , 0 , 1
144
144
};
145
145
146
- constexpr double GT_2x2_5x5[] = {
146
+ constexpr float GT_2x2_5x5[] = {
147
147
1 , 1 ./6 , -1 ./6 , -4 ./15 , 1 ./60 , 0 ,
148
148
0 , -1 ./6 , -1 ./6 , 2 ./15 , 1 ./30 , 0 ,
149
149
0 , 1 ./6 , -1 ./6 , -1 ./15 , 1 ./15 , 0 ,
150
150
0 , -1 ./6 , -1 ./6 , 1 ./30 , 2 ./15 , 0 ,
151
151
0 , 1 ./6 , -1 ./6 , -1 ./60 , 4 ./15 , 1
152
152
};
153
153
154
- constexpr double BT_2x2_5x5[] = {
154
+ constexpr float BT_2x2_5x5[] = {
155
155
1 ./8 , 3 ./16 , -1 ./4 , -3 ./16 , 1 ./8 , 0 ,
156
156
0 , 1 ./8 , 1 ./16 , -5 ./16 , 1 ./8 , 0 ,
157
157
0 , -1 ./8 , -5 ./16 , -1 ./16 , 1 ./8 , 0 ,
@@ -160,7 +160,7 @@ constexpr double BT_2x2_5x5[] = {
160
160
0 , 1 ./8 , 3 ./16 , -1 ./4 , -3 ./16 , 1 ./8
161
161
};
162
162
163
- constexpr double B_2x2_5x5[] = {
163
+ constexpr float B_2x2_5x5[] = {
164
164
1 ./8 , 0 , 0 , 0 , 0 , 0 ,
165
165
3 ./16 , 1 ./8 , -1 ./8 , 1 ./4 , -1 ./8 , 1 ./8 ,
166
166
-1 ./4 , 1 ./16 , -5 ./16 , -1 ./8 , -1 ./4 , 3 ./16 ,
@@ -169,12 +169,12 @@ constexpr double B_2x2_5x5[] = {
169
169
0 , 0 , 0 , 0 , 0 , 1 ./8
170
170
};
171
171
172
- constexpr double AT_2x2_5x5[] = {
172
+ constexpr float AT_2x2_5x5[] = {
173
173
1 ./2 , 1 , 1 , 2 , 1 , 0 ,
174
174
0 , -1 , 1 , -1 , 2 , 1 ./2
175
175
};
176
176
177
- constexpr double A_2x2_5x5[] = {
177
+ constexpr float A_2x2_5x5[] = {
178
178
1 ./2 , 0 ,
179
179
1 , -1 ,
180
180
1 , 1 ,
@@ -186,30 +186,27 @@ constexpr double A_2x2_5x5[] = {
186
186
187
187
// / Structure to keep information of constant transform matrices.
188
188
struct TransformMatrix {
189
- TransformMatrix (ArrayRef<double > table, int64_t rows, int64_t cols,
189
+ TransformMatrix (ArrayRef<float > table, int64_t rows, int64_t cols,
190
190
int64_t scalarFactor = 1 )
191
- : table(llvm::map_to_vector( table, []( double val) { return APFloat (val); }) ), rows(rows), cols(cols), scalarFactor(scalarFactor) {
191
+ : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {
192
192
}
193
193
194
- SmallVector<APFloat > table;
194
+ ArrayRef< float > table;
195
195
int64_t rows;
196
196
int64_t cols;
197
197
int64_t scalarFactor;
198
198
};
199
199
200
200
// / Utility function to convert constant array to arith.constant Value.
201
201
Value create2DTransformMatrix (OpBuilder &builder, Location loc,
202
- TransformMatrix transform, Type type) {
203
- assert (type.isFloat ());
204
- assert (transform.table .size () == (transform.rows * transform.cols ));
205
- ArrayRef<APFloat> constVec (transform.table .data (), transform.rows * transform.cols );
206
-
202
+ TransformMatrix transform) {
203
+ assert (transform.table .size () == static_cast <size_t >(transform.rows * transform.cols ));
204
+ ArrayRef<float > constVec (transform.table .data (), transform.rows * transform.cols );
205
+ SmallVector<int64_t , 2 > shape{transform.rows , transform.cols };
207
206
return arith::ConstantOp::create (
208
207
builder, loc,
209
208
DenseFPElementsAttr::get (
210
- RankedTensorType::get (
211
- SmallVector<int64_t >{transform.rows , transform.cols }, type),
212
- constVec));
209
+ RankedTensorType::get (shape, builder.getF32Type ()), constVec));
213
210
}
214
211
215
212
// / Extract height x width data from 4D tensors.
@@ -407,7 +404,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
407
404
auto init =
408
405
linalg::FillOp::create (builder, loc, zero, empty).getResult (0 );
409
406
410
- Value G = create2DTransformMatrix (builder, loc, GMatrix, elementType );
407
+ Value G = create2DTransformMatrix (builder, loc, GMatrix);
411
408
// Multiply G x g.
412
409
auto matmulOp = linalg::MatmulOp::create (builder, loc, matmulType,
413
410
ValueRange{G, extractFilter},
@@ -430,7 +427,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
430
427
auto init =
431
428
linalg::FillOp::create (builder, loc, zero, empty).getResult (0 );
432
429
433
- Value GT = create2DTransformMatrix (builder, loc, GTMatrix, elementType );
430
+ Value GT = create2DTransformMatrix (builder, loc, GTMatrix);
434
431
// Multiply u = (G x g) x GT.
435
432
auto matmulOp = linalg::MatmulOp::create (builder, loc, matmulType,
436
433
ValueRange{matmulRetValue, GT},
@@ -500,6 +497,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
500
497
std::tie (m, r) = getFmrFromWinogradConv2DFmr (fmr);
501
498
auto inputType = cast<ShapedType>(input.getType ());
502
499
Type elementType = inputType.getElementType ();
500
+ // assert(elementType.isF32() && "NYI: support non-f32");
503
501
auto inputShape = inputType.getShape (); // N, H, W, C
504
502
int64_t inputN = inputShape[0 ];
505
503
int64_t inputC = inputShape[3 ];
@@ -555,7 +553,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
555
553
linalg::FillOp::create (builder, loc, zero, empty).getResult (0 );
556
554
557
555
Value BT =
558
- create2DTransformMatrix (builder, loc, BTMatrix, builder. getF32Type () );
556
+ create2DTransformMatrix (builder, loc, BTMatrix);
559
557
// Multiply BT x d.
560
558
auto matmulOp = linalg::MatmulOp::create (builder, loc, matmulType,
561
559
ValueRange{BT, matmulRetValue},
@@ -578,7 +576,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
578
576
auto init =
579
577
linalg::FillOp::create (builder, loc, zero, empty).getResult (0 );
580
578
Value B =
581
- create2DTransformMatrix (builder, loc, BMatrix, builder. getF32Type () );
579
+ create2DTransformMatrix (builder, loc, BMatrix);
582
580
// Multiply v = (BT x d) x B.
583
581
auto matmulOp = linalg::MatmulOp::create (builder, loc, matmulType,
584
582
ValueRange{matmulRetValue, B},
@@ -723,6 +721,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
723
721
std::tie (m, r) = getFmrFromWinogradConv2DFmr (fmr);
724
722
auto valueType = cast<ShapedType>(value.getType ());
725
723
Type elementType = valueType.getElementType ();
724
+ // assert(elementType.isF32() && "NYI: support non-f32");
726
725
auto valueShape = valueType.getShape (); // H, W, TileH, TileW, N, F
727
726
int64_t valueH = valueShape[0 ];
728
727
int64_t valueW = valueShape[1 ];
@@ -786,7 +785,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
786
785
init = linalg::FillOp::create (builder, loc, zero, empty).getResult (0 );
787
786
}
788
787
789
- Value AT = create2DTransformMatrix (builder, loc, ATMatrix, elementType );
788
+ Value AT = create2DTransformMatrix (builder, loc, ATMatrix);
790
789
// Multiply AT x m.
791
790
auto matmulOp = linalg::MatmulOp::create (builder, loc, matmulType,
792
791
ValueRange{AT, matmulRetValue},
@@ -805,7 +804,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
805
804
init = linalg::FillOp::create (builder, loc, zero, empty).getResult (0 );
806
805
}
807
806
808
- Value A = create2DTransformMatrix (builder, loc, AMatrix, elementType );
807
+ Value A = create2DTransformMatrix (builder, loc, AMatrix);
809
808
// Multiply y = (AT x m) x A.
810
809
auto matmulOp = linalg::MatmulOp::create (builder, loc, matmulType,
811
810
ValueRange{matmulRetValue, A},
0 commit comments