@@ -46,46 +46,46 @@ namespace {
46
46
// / BTMatrices, BMatrices, ATMatrices, or AMatrices map.
47
47
// / 3. Add a enum value F_m_r to WinogradConv2DFmr enum.
48
48
// /
49
- constexpr float G_2x2_3x3[] = {
49
+ constexpr double G_2x2_3x3[] = {
50
50
-1 , 0 , 0 ,
51
51
1 ./2 , -1 ./2 , 1 ./2 ,
52
52
1 ./2 , 1 ./2 , 1 ./2 ,
53
53
0 , 0 , 1
54
54
};
55
55
56
- constexpr float GT_2x2_3x3[] = {
56
+ constexpr double GT_2x2_3x3[] = {
57
57
-1 , 1 ./2 , 1 ./2 , 0 ,
58
58
0 , -1 ./2 , 1 ./2 , 0 ,
59
59
0 , 1 ./2 , 1 ./2 , 1
60
60
};
61
61
62
- constexpr float BT_2x2_3x3[] = {
62
+ constexpr double BT_2x2_3x3[] = {
63
63
-1 , 0 , 1 , 0 ,
64
64
0 , -1 , 1 , 0 ,
65
65
0 , 1 , 1 , 0 ,
66
66
0 , -1 , 0 , 1
67
67
};
68
68
69
- constexpr float B_2x2_3x3[] = {
69
+ constexpr double B_2x2_3x3[] = {
70
70
-1 , 0 , 0 , 0 ,
71
71
0 , -1 , 1 , -1 ,
72
72
1 , 1 , 1 , 0 ,
73
73
0 , 0 , 0 , 1
74
74
};
75
75
76
- constexpr float AT_2x2_3x3[] = {
76
+ constexpr double AT_2x2_3x3[] = {
77
77
1 , 1 , 1 , 0 ,
78
78
0 , -1 , 1 , 1
79
79
};
80
80
81
- constexpr float A_2x2_3x3[] = {
81
+ constexpr double A_2x2_3x3[] = {
82
82
1 , 0 ,
83
83
1 , -1 ,
84
84
1 , 1 ,
85
85
0 , 1
86
86
};
87
87
88
- constexpr float G_4x4_3x3[] = {
88
+ constexpr double G_4x4_3x3[] = {
89
89
1 , 0 , 0 ,
90
90
-1 ./3 , 1 ./3 , -1 ./3 ,
91
91
-1 ./3 , -1 ./3 , -1 ./3 ,
@@ -94,13 +94,13 @@ constexpr float G_4x4_3x3[] = {
94
94
0 , 0 , 1
95
95
};
96
96
97
- constexpr float GT_4x4_3x3[] = {
97
+ constexpr double GT_4x4_3x3[] = {
98
98
1 , -1 ./3 , -1 ./3 , 1 ./12 , 1 ./12 , 0 ,
99
99
0 , 1 ./3 , -1 ./3 , -1 ./6 , 1 ./6 , 0 ,
100
100
0 , -1 ./3 , -1 ./3 , 1 ./3 , 1 ./3 , 1
101
101
};
102
102
103
- constexpr float BT_4x4_3x3[] = {
103
+ constexpr double BT_4x4_3x3[] = {
104
104
1 ./4 , 0 , -5 ./16 , 0 , 1 ./16 , 0 ,
105
105
0 , 1 ./4 , -1 ./4 , -1 ./16 , 1 ./16 , 0 ,
106
106
0 , -1 ./4 , -1 ./4 , 1 ./16 , 1 ./16 , 0 ,
@@ -109,7 +109,7 @@ constexpr float BT_4x4_3x3[] = {
109
109
0 , 1 ./4 , 0 , -5 ./16 , 0 , 1 ./16
110
110
};
111
111
112
- constexpr float B_4x4_3x3[] = {
112
+ constexpr double B_4x4_3x3[] = {
113
113
1 ./4 , 0 , 0 , 0 , 0 , 0 ,
114
114
0 , 1 ./4 , -1 ./4 , 1 ./4 , -1 ./4 , 1 ./4 ,
115
115
-5 ./16 , -1 ./4 , -1 ./4 , -1 ./8 , -1 ./8 , 0 ,
@@ -118,14 +118,14 @@ constexpr float B_4x4_3x3[] = {
118
118
0 , 0 , 0 , 0 , 0 , 1 ./16
119
119
};
120
120
121
- constexpr float AT_4x4_3x3[] = {
121
+ constexpr double AT_4x4_3x3[] = {
122
122
1 ./8 , 1 ./4 , 1 ./4 , 1 ./8 , 1 ./8 , 0 ,
123
123
0 , -1 ./4 , 1 ./4 , -1 ./4 , 1 ./4 , 0 ,
124
124
0 , 1 ./4 , 1 ./4 , 1 ./2 , 1 ./2 , 0 ,
125
125
0 , -1 ./4 , 1 ./4 , -1 , 1 , 1 ./2
126
126
};
127
127
128
- constexpr float A_4x4_3x3[] = {
128
+ constexpr double A_4x4_3x3[] = {
129
129
1 ./8 , 0 , 0 , 0 ,
130
130
1 ./4 , -1 ./4 , 1 ./4 , -1 ./4 ,
131
131
1 ./4 , 1 ./4 , 1 ./4 , 1 ./4 ,
@@ -134,7 +134,7 @@ constexpr float A_4x4_3x3[] = {
134
134
0 , 0 , 0 , 1 ./2
135
135
};
136
136
137
- constexpr float G_2x2_5x5[] = {
137
+ constexpr double G_2x2_5x5[] = {
138
138
1 , 0 , 0 , 0 , 0 ,
139
139
1 ./6 , -1 ./6 , 1 ./6 , -1 ./6 , 1 ./6 ,
140
140
-1 ./6 , -1 ./6 , -1 ./6 , -1 ./6 , -1 ./6 ,
@@ -143,15 +143,15 @@ constexpr float G_2x2_5x5[] = {
143
143
0 , 0 , 0 , 0 , 1
144
144
};
145
145
146
- constexpr float GT_2x2_5x5[] = {
146
+ constexpr double GT_2x2_5x5[] = {
147
147
1 , 1 ./6 , -1 ./6 , -4 ./15 , 1 ./60 , 0 ,
148
148
0 , -1 ./6 , -1 ./6 , 2 ./15 , 1 ./30 , 0 ,
149
149
0 , 1 ./6 , -1 ./6 , -1 ./15 , 1 ./15 , 0 ,
150
150
0 , -1 ./6 , -1 ./6 , 1 ./30 , 2 ./15 , 0 ,
151
151
0 , 1 ./6 , -1 ./6 , -1 ./60 , 4 ./15 , 1
152
152
};
153
153
154
- constexpr float BT_2x2_5x5[] = {
154
+ constexpr double BT_2x2_5x5[] = {
155
155
1 ./8 , 3 ./16 , -1 ./4 , -3 ./16 , 1 ./8 , 0 ,
156
156
0 , 1 ./8 , 1 ./16 , -5 ./16 , 1 ./8 , 0 ,
157
157
0 , -1 ./8 , -5 ./16 , -1 ./16 , 1 ./8 , 0 ,
@@ -160,7 +160,7 @@ constexpr float BT_2x2_5x5[] = {
160
160
0 , 1 ./8 , 3 ./16 , -1 ./4 , -3 ./16 , 1 ./8
161
161
};
162
162
163
- constexpr float B_2x2_5x5[] = {
163
+ constexpr double B_2x2_5x5[] = {
164
164
1 ./8 , 0 , 0 , 0 , 0 , 0 ,
165
165
3 ./16 , 1 ./8 , -1 ./8 , 1 ./4 , -1 ./8 , 1 ./8 ,
166
166
-1 ./4 , 1 ./16 , -5 ./16 , -1 ./8 , -1 ./4 , 3 ./16 ,
@@ -169,12 +169,12 @@ constexpr float B_2x2_5x5[] = {
169
169
0 , 0 , 0 , 0 , 0 , 1 ./8
170
170
};
171
171
172
- constexpr float AT_2x2_5x5[] = {
172
+ constexpr double AT_2x2_5x5[] = {
173
173
1 ./2 , 1 , 1 , 2 , 1 , 0 ,
174
174
0 , -1 , 1 , -1 , 2 , 1 ./2
175
175
};
176
176
177
- constexpr float A_2x2_5x5[] = {
177
+ constexpr double A_2x2_5x5[] = {
178
178
1 ./2 , 0 ,
179
179
1 , -1 ,
180
180
1 , 1 ,
@@ -186,11 +186,12 @@ constexpr float A_2x2_5x5[] = {
186
186
187
187
// / Structure to keep information of constant transform matrices.
188
188
struct TransformMatrix {
189
- TransformMatrix (const float * table, int64_t rows, int64_t cols,
189
+ TransformMatrix (ArrayRef< double > table, int64_t rows, int64_t cols,
190
190
int64_t scalarFactor = 1 )
191
- : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
191
+ : table(llvm::map_to_vector(table, [](double val) { return APFloat (val); })), rows(rows), cols(cols), scalarFactor(scalarFactor) {
192
+ }
192
193
193
- const float * table;
194
+ SmallVector<APFloat> table;
194
195
int64_t rows;
195
196
int64_t cols;
196
197
int64_t scalarFactor;
@@ -199,7 +200,9 @@ struct TransformMatrix {
199
200
// / Utility function to convert constant array to arith.constant Value.
200
201
Value create2DTransformMatrix (OpBuilder &builder, Location loc,
201
202
TransformMatrix transform, Type type) {
202
- ArrayRef<float > constVec (transform.table , transform.rows * transform.cols );
203
+ assert (type.isFloat ());
204
+ assert (transform.table .size () == (transform.rows * transform.cols ));
205
+ ArrayRef<APFloat> constVec (transform.table .data (), transform.rows * transform.cols );
203
206
204
207
return arith::ConstantOp::create (
205
208
builder, loc,
0 commit comments