@@ -208,7 +208,7 @@ void testIm2colCPU(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
208
208
209
209
void benchIm2col (int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
210
210
PREPARE_IM2COL_CPU;
211
- constexpr int repeat = 30 ;
211
+ constexpr int repeat = 100 ;
212
212
auto GetCurrentMs = []() -> double {
213
213
struct timeval time;
214
214
gettimeofday (&time, NULL );
@@ -231,17 +231,39 @@ void benchIm2col(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
231
231
}
232
232
233
233
TEST (math, im2col_cputest) {
234
- testIm2colCPU (/* ic*/ 2 , /* ih*/ 5 , /* iw*/ 4 , /* fh*/ 3 , /* fw*/ 3 , /* ph*/ 0 ,
235
- /* pw*/ 0 );
236
- testIm2colCPU (/* ic*/ 2 , /* ih*/ 5 , /* iw*/ 4 , /* fh*/ 3 , /* fw*/ 3 , /* ph*/ 1 ,
237
- /* pw*/ 1 );
234
+ // padding_h == padding_w
235
+ for (int p = 0 ; p < 4 ; ++p) {
236
+ // width == height
237
+ testIm2colCPU (/* ic*/ 2 , /* ih*/ 5 , /* iw*/ 5 , /* fh*/ 4 , /* fw*/ 4 , /* ph*/ p,
238
+ /* pw*/ p);
239
+ testIm2colCPU (/* ic*/ 2 , /* ih*/ 4 , /* iw*/ 4 , /* fh*/ 3 , /* fw*/ 3 , /* ph*/ p,
240
+ /* pw*/ p);
241
+ testIm2colCPU (/* ic*/ 2 , /* ih*/ 4 , /* iw*/ 4 , /* fh*/ 2 , /* fw*/ 2 , /* ph*/ p,
242
+ /* pw*/ p);
238
243
239
- benchIm2col (/* ic*/ 3 , /* ih*/ 224 , /* iw*/ 224 , /* fh*/ 3 , /* fw*/ 3 , /* ph*/ 1 ,
240
- /* pw*/ 1 );
244
+ // height != width
245
+ testIm2colCPU (/* ic*/ 2 , /* ih*/ 5 , /* iw*/ 4 , /* fh*/ 2 , /* fw*/ 3 , /* ph*/ p,
246
+ /* pw*/ p);
247
+
248
+ // filter == 1
249
+ testIm2colCPU (/* ic*/ 3 , /* ih*/ 4 , /* iw*/ 4 , /* fh*/ 1 , /* fw*/ 1 , /* ph*/ p,
250
+ /* pw*/ p);
251
+ testIm2colCPU (/* ic*/ 3 , /* ih*/ 3 , /* iw*/ 4 , /* fh*/ 1 , /* fw*/ 1 , /* ph*/ p,
252
+ /* pw*/ p);
253
+ }
254
+ // padding_h != padding_w
255
+ testIm2colCPU (/* ic*/ 2 , /* ih*/ 4 , /* iw*/ 4 , /* fh*/ 2 , /* fw*/ 3 , /* ph*/ 1 ,
256
+ /* pw*/ 2 );
257
+
258
+ // benchmark
259
+ LOG (INFO) << " padding == 0" ;
241
260
benchIm2col (/* ic*/ 3 , /* ih*/ 224 , /* iw*/ 224 , /* fh*/ 3 , /* fw*/ 3 , /* ph*/ 0 ,
242
261
/* pw*/ 0 );
243
- benchIm2col (/* ic*/ 3 , /* ih*/ 224 , /* iw*/ 224 , /* fh*/ 5 , /* fw*/ 5 , /* ph*/ 1 ,
244
- /* pw*/ 1 );
245
262
benchIm2col (/* ic*/ 3 , /* ih*/ 224 , /* iw*/ 224 , /* fh*/ 5 , /* fw*/ 5 , /* ph*/ 0 ,
246
263
/* pw*/ 0 );
264
+ LOG (INFO) << " padding == 1" ;
265
+ benchIm2col (/* ic*/ 3 , /* ih*/ 224 , /* iw*/ 224 , /* fh*/ 3 , /* fw*/ 3 , /* ph*/ 1 ,
266
+ /* pw*/ 1 );
267
+ benchIm2col (/* ic*/ 3 , /* ih*/ 224 , /* iw*/ 224 , /* fh*/ 5 , /* fw*/ 5 , /* ph*/ 1 ,
268
+ /* pw*/ 1 );
247
269
}
0 commit comments