@@ -24,19 +24,18 @@ __kernel void pool_max(__read_only image2d_t input,
24
24
__private const int ksize_w ,
25
25
__private const int stride_h ,
26
26
__private const int stride_w ,
27
- __private const int pad_top ,
28
- __private const int pad_left ) {
27
+ __private const int4 pad ) {
29
28
const int out_c = get_global_id (0 );
30
29
const int out_w = get_global_id (1 );
31
30
const int out_nh = get_global_id (2 );
32
31
const int out_n = out_nh / out_height ;
33
32
const int out_h = out_nh % out_height ;
34
33
35
- int start_h = out_h * stride_h - pad_top ;
34
+ int start_h = out_h * stride_h - ( pad . x - pad . y ) ;
36
35
int end_h = min (start_h + ksize_h , in_height );
37
36
start_h = max (start_h , 0 );
38
37
39
- int start_w = out_w * stride_w - pad_left ;
38
+ int start_w = out_w * stride_w - ( pad . w - pad . z ) ;
40
39
int end_w = min (start_w + ksize_w , in_width );
41
40
start_w = max (start_w , 0 );
42
41
@@ -65,19 +64,18 @@ __kernel void pool_avg(__read_only image2d_t input,
65
64
__private const int ksize_w ,
66
65
__private const int stride_h ,
67
66
__private const int stride_w ,
68
- __private const int pad_top ,
69
- __private const int pad_left ) {
67
+ __private const int4 pad ) {
70
68
const int out_c = get_global_id (0 );
71
69
const int out_w = get_global_id (1 );
72
70
const int out_nh = get_global_id (2 );
73
71
const int out_n = out_nh / out_height ;
74
72
const int out_h = out_nh % out_height ;
75
73
76
- int start_h = out_h * stride_h - pad_top ;
74
+ int start_h = out_h * stride_h - pad . x ;
77
75
int end_h = min (start_h + ksize_h , in_height );
78
76
start_h = max (start_h , 0 );
79
77
80
- int start_w = out_w * stride_w - pad_left ;
78
+ int start_w = out_w * stride_w - pad . z ;
81
79
int end_w = min (start_w + ksize_w , in_width );
82
80
start_w = max (start_w , 0 );
83
81
@@ -96,7 +94,7 @@ __kernel void pool_avg(__read_only image2d_t input,
96
94
div = (CL_DTYPE )((end_h - start_h )* (end_w - start_w ));
97
95
#else
98
96
div = (CL_DTYPE )(ksize_w * ksize_h );
99
- #endif
97
+ #endif
100
98
CL_DTYPE4 avg = sum / div ;
101
99
const int pos_out_x = mad24 (out_c , out_width , out_w );
102
100
WRITE_IMG_TYPE (CL_DTYPE_CHAR , output , (int2 )(pos_out_x , out_nh ), avg );
@@ -112,8 +110,7 @@ __kernel void pool_avg_global(__read_only image2d_t input,
112
110
__private const int ksize_w ,
113
111
__private const int stride_h ,
114
112
__private const int stride_w ,
115
- __private const int pad_top ,
116
- __private const int pad_left ) {
113
+ __private const int4 pad ) {
117
114
const int out_c = get_global_id (0 );
118
115
const int out_w = get_global_id (1 ); // =1
119
116
const int out_nh = get_global_id (2 ); // = n*1
@@ -182,8 +179,7 @@ __kernel void pool_max_global(__read_only image2d_t input,
182
179
__private const int ksize_w ,
183
180
__private const int stride_h ,
184
181
__private const int stride_w ,
185
- __private const int pad_top ,
186
- __private const int pad_left ) {
182
+ __private const int4 pad ) {
187
183
const int out_c = get_global_id (0 );
188
184
const int out_w = get_global_id (1 ); // =1
189
185
const int out_nh = get_global_id (2 ); // = n*1
0 commit comments