1818
1919layout (std430) buffer ;
2020
21- layout (set = 0 , binding = 0 , ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
22- layout (set = 0 , binding = 1 ) uniform PRECISION sampler3D image_in;
23- layout (set = 0 , binding = 2 ) uniform PRECISION sampler3D kernel_in;
24- layout (set = 0 , binding = 3 ) uniform PRECISION sampler3D bias_in;
25-
26- layout (set = 0 , binding = 4 ) uniform PRECISION restrict OutLimits {
27- ivec3 out_limits;
28- };
29-
30- layout (set = 0 , binding = 5 ) uniform PRECISION restrict InSizes {
31- ivec4 in_sizes;
32- };
33-
34- layout (set = 0 , binding = 6 ) uniform PRECISION restrict Params {
35- int kernel_size;
36- int stride;
37- int padding;
38- int dilation;
39- int in_group_size;
40- int out_group_size;
41- };
42-
43- layout (set = 0 , binding = 7 ) uniform PRECISION restrict OutputParams {
44- float out_min;
45- float out_max;
46- };
21+ ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
22+ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
23+ ${layout_declare_tensor(B, "r", "kernel_in", DTYPE, STORAGE)}
24+ ${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}
25+
26+ ${layout_declare_ubo(B, "ivec3 ", "out_limits")}
27+ ${layout_declare_ubo(B, "ivec4 ", "in_sizes")}
28+
29+ ${layout_declare_ubo(B, "ivec4 ", "out_axis_map")}
30+ ${layout_declare_ubo(B, "ivec4 ", "in_axis_map")}
31+ ${layout_declare_ubo(B, "ivec4 ", "kernel_axis_map")}
32+ ${layout_declare_ubo(B, "ivec4 ", "bias_axis_map")}
33+
34+ ${layout_declare_ubo(B,"int ", "kernel_size", "int ", "stride", "int ", "padding", "int ", "dilation", "int ", "in_group_size", "int ", "out_group_size")}
35+
36+ ${layout_declare_ubo(B, "float ", "out_min", "float ", "out_max")}
4737
4838layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
4939
@@ -67,9 +57,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
6757// shader invocations, where each invocation computes 1 result. But that
6858// performs worse.
6959void main() {
70- const ivec3 pos = ivec3 (gl_GlobalInvocationID);
60+ const ivec3 lpos = ivec3 (gl_GlobalInvocationID);
7161
72- if (any (greaterThanEqual (pos , out_limits))) {
62+ if (any (greaterThanEqual (lpos , out_limits))) {
7363 return ;
7464 }
7565
@@ -78,8 +68,8 @@ void main() {
7868
7969 // "out_c" is the output's channel index where we write our result.
8070 // Across shader invocations, this is the only value that varies.
81- int out_c = pos .y;
82- vec4 bias = texelFetch (bias_in, ivec3 (out_c, 0 , 0 ), 0 );
71+ int out_c = lpos .y;
72+ VEC4_T bias = load_texel_lpos (bias_in, ivec3 (out_c, 0 , 0 ), bias_axis_map );
8373
8474 // "in_c" tracks the input's channel start index.
8575 // We iterate over the input group that corresponds to the output group.
@@ -98,7 +88,7 @@ void main() {
9888 int out_l = 0 ;
9989
10090 for (int in_l = l_start; in_l < l_end; in_l += stride, ++ out_l) {
101- vec4 sum = vec4 (0 );
91+ VEC4_T sum = VEC4_T (0 );
10292
10393 for (int in_c = c_start; in_c < c_end; ++ in_c) {
10494 // "k" tracks the kernel's index for our input-kernel computation.
@@ -107,25 +97,25 @@ void main() {
10797 for (int k = 0 ; k < kernel_size; k += 4 ) {
10898 // Since the weight tensor is width-packed, which is along the length
10999 // dimension, we can batch-read four elements at a time.
110- const ivec3 w_pos = ivec3 (k / 4 , in_c % in_group_size, out_c);
111- const vec4 weight = texelFetch (kernel_in, w_pos, 0 );
100+ const ivec3 w_lpos = ivec3 (k / 4 , in_c % in_group_size, out_c);
101+ const VEC4_T weight = load_texel_lpos (kernel_in, w_lpos, kernel_axis_map );
112102
113- const ivec3 in_pos_0 = ivec3 (in_l + k * dilation, in_c, n / 4 );
114- sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0 ), sum);
103+ ivec3 in_pos = lpos_to_pos( ivec3 (in_l + k * dilation, in_c, n / 4 ), in_axis_map );
104+ sum = fma(weight.xxxx, load_texel(t_in, in_pos ), sum);
115105
116- const ivec3 in_pos_1 = ivec3 (in_l + (k + 1 ) * dilation, in_c, n / 4 ) ;
117- sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0 ), sum);
106+ in_pos[in_axis_map.x] += dilation;
107+ sum = fma(weight.yyyy, load_texel(t_in, in_pos ), sum);
118108
119- const ivec3 in_pos_2 = ivec3 (in_l + (k + 2 ) * dilation, in_c, n / 4 ) ;
120- sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0 ), sum);
109+ in_pos[in_axis_map.x] += dilation;
110+ sum = fma(weight.zzzz, load_texel(t_in, in_pos ), sum);
121111
122- const ivec3 in_pos_3 = ivec3 (in_l + (k + 3 ) * dilation, in_c, n / 4 ) ;
123- sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0 ), sum);
112+ in_pos[in_axis_map.x] += dilation;
113+ sum = fma(weight.wwww, load_texel(t_in, in_pos ), sum);
124114 }
125115 }
126116
127- ivec3 out_pos = ivec3 (out_l, out_c, n / 4 );
128- imageStore(image_out, out_pos , op(sum + bias.x, out_min, out_max));
117+ const ivec3 out_lpos = ivec3 (out_l, out_c, n / 4 );
118+ write_texel_lpos(t_out, out_lpos , op(sum + bias.x, out_min, out_max), out_axis_map );
129119 }
130120 }
131121}
0 commit comments