1212#define VEC4_T ${texel_load_type(DTYPE, "buffer ")}
1313#define T ${texel_load_component_type(DTYPE, "buffer ")}
1414
15+ // #define DEBUG_MODE 1
16+
17+ #extension GL_EXT_control_flow_attributes : require
18+
1519$if IO_STORAGE == "buffer ":
1620 #define PACKED_INT8_OUTPUT_BUFFER
1721 #define PACKED_INT8_INPUT_BUFFER
@@ -27,9 +31,9 @@ layout(std430) buffer;
2731
2832#include "conv2d_common.glslh"
2933
30- ${layout_declare_tensor(B, "w", "t_packed_int8_output", "int ", IO_STORAGE, is_scalar_array= False )}
31- ${layout_declare_tensor(B, "r", "t_packed_int8_input", "int ", IO_STORAGE, is_scalar_array= False )}
32- ${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int ", WEIGHT_STORAGE, is_scalar_array= False )}
34+ ${layout_declare_tensor(B, "w", "t_packed_int8_output", "int ", IO_STORAGE, is_scalar_array= True )}
35+ ${layout_declare_tensor(B, "r", "t_packed_int8_input", "int ", IO_STORAGE, is_scalar_array= True )}
36+ ${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int ", WEIGHT_STORAGE, is_scalar_array= True )}
3337${layout_declare_tensor(B, "r", "t_weight_sums", "int ", "buffer ", is_scalar_array= False)}
3438${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer ", is_scalar_array= False)}
3539${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer ", is_scalar_array= False)}
@@ -51,71 +55,183 @@ ${layout_declare_spec_const(C, "int", "apply_bias", "1")}
5155
5256#include "conv2d_dw_q8_utils.glslh"
5357
// Load one packed 4xint8 word of weights covering output channels
// 4*oc4 .. 4*oc4+3 at kernel position (kx, ky).
// Equivalent to unpacked_weights[ky][kx][4*oc4 : 4*oc4 + 4].
// (Fixed doc: the channel base is 4*oc4, not oc4/4.)
int load_weight(
    int kx, // kernel w coordinate to load
    int ky, // kernel h coordinate to load
    int oc4, // output-channel block to load (4 channels per block)
    int kw4, // kernel width / 4 (rounded up)
    int OC4 // out channels / 4 (rounded up)
) {
  // Weights are packed in 4W4C blocks: for each (ky, W-block, channel-block)
  // there are 4 consecutive int32 words, one per W offset within the block.
  const int kx4 = kx / 4; // W block index
  const int linear_idx = ((ky * kw4 + kx4) * OC4 + oc4) * 4;
  const int block_x_offset = kx % 4; // W offset within the 4W4C block
  return t_packed_int8_weight[linear_idx + block_x_offset];
}
75+
// Load one packed 4xint8 word of inputs covering channels
// 4*oc4 .. 4*oc4+3 at spatial location (x, y). Out-of-bounds locations
// return a word filled with input_zp (the quantization zero point).
// (Fixed doc: the channel base is 4*oc4, not ic4/4.)
// NOTE(review): OC4 is currently unused in the body; kept so the
// call signature used by main() stays unchanged.
int load_input(
    int x, // w coordinate
    int y, // h coordinate
    int oc4, // channel block (input channel == output channel for depthwise)
    int OC4, // out channels / 4 (rounded up) -- unused
    Conv2dBlockExtents block_extents
) {
  // W-block containing x. For negative x, integer division truncates toward
  // zero (x in [-3, -1] yields block 0), so the explicit x >= 0 check below
  // is required to reject reads into the left padding.
  const int block_w = x / 4;

  if (in_bounds(block_w, y, oc4, block_extents) && x >= 0) {
#ifdef PACKED_INT8_INPUT_BUFFER
    const int buffer_idx =
        (y * block_extents.data_xz + block_w * block_extents.data.z + oc4) * 4 + (x % 4);
    return t_packed_int8_input[buffer_idx];
#else
    // Texture-backed input storage is not implemented yet.
    #error Unimplemented
    // input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, oc4), 0);
#endif
  } else {
    return pack_into_int32(ivec4(input_zp));
  }
}
100+
54101void main() {
  // Depthwise int8 conv2d: accumulate a 4-wide x 4-channel output tile in
  // int32, then dequantize, add bias, requantize, and store packed int8x4.
55102 const int tid = int (gl_GlobalInvocationID.x);
56103 Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes);
104+ const int subtile = int (gl_GlobalInvocationID.y); // x offset within a block
57105
58106 Conv2dBlockIndex out_block_idx = linear_idx_to_block_idx(
59107 tid, out_block_extents);
60108
61- if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) {
109+ if (block_idx_out_of_bounds(out_block_idx, out_block_extents) || subtile >= 2 ) {
62110 return ;
63111 }
64112
65- const int out_w = mul_4(out_block_idx.data.x);
66- const int w_start =
67- (out_w * conv2d_params.stride.x) - conv2d_params.padding.x;
68- const int w_end = ((out_w + 3 ) * conv2d_params.stride.x) -
69- conv2d_params.padding.x +
70- (conv2d_params.kernel_size.x - 1 ) * conv2d_params.dilation.x;
113+ const int out_h = out_block_idx.data.y;
114+ const int out_w = mul_4(out_block_idx.data.x) + subtile * 2 ;
  // NOTE(review): out_w is offset by subtile * 2, but the loop below still
  // computes 4 consecutive outputs and the store at the end indexes only by
  // tid; with subtile in {0, 1} two invocations appear to write the same
  // words. Looks like an in-progress tiling change -- confirm intent.
71115
72- Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes );
116+ // debugPrintfEXT("tid: %d, subtile: %d\\nComputing output %d %d\\n", tid, subtile, out_h, out_w );
73117
74- const ivec4 input_zps = ivec4 (pack_into_int32(ivec4 (input_zp)));
75- const vec4 weight_scales = vec4 (t_weight_scales[out_block_idx.data.z]);
118+ Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);
76119
77120 const int Kw4 = div_up_4(conv2d_params.kernel_size.x);
78121
79- FPOutBlock out_block;
122+ // Compute 4 channels for 4 output elements.
123+ ivec4 acc0 = ivec4 (0 );
124+ ivec4 acc1 = ivec4 (0 );
125+ ivec4 acc2 = ivec4 (0 );
126+ ivec4 acc3 = ivec4 (0 );
127+
80128 for (int ky = 0 ; ky < conv2d_params.kernel_size.y; ky++ ) {
81- const int out_h = out_block_idx.data.y;
82129 const int h = out_h * conv2d_params.stride.y - conv2d_params.padding.y +
83130 ky * conv2d_params.dilation.y;
84131
85- InputWindow1D input_window = load_input_window(
86- w_start,
87- w_end,
88- h,
89- out_block_idx.data.z,
90- in_block_extents,
91- input_scale,
92- input_zp,
93- input_zps);
94-
95- WeightRow weight_row = load_weight_row(
96- out_block_idx.data.z,
97- ky,
98- out_block_extents.data.z,
99- conv2d_params.kernel_size.x,
100- Kw4,
101- weight_scales);
102-
103- perform_conv1d(out_block, input_window, weight_row);
132+ for (int kx = 0 ; kx < conv2d_params.kernel_size.x; kx++ ) {
133+ const int w = out_w * conv2d_params.stride.x - conv2d_params.padding.x +
134+ kx * conv2d_params.dilation.x;
135+
136+ // Load, unpack, and dequantize weights.
137+ const int packed_weight_4c = load_weight(
138+ kx,
139+ ky,
140+ out_block_idx.data.z,
141+ Kw4,
142+ out_block_extents.data.z
143+ );
144+
145+ const ivec4 weight_4c = unpack_int8x4(packed_weight_4c);
146+
147+ // Load, unpack, and dequantize inputs.
  // The four outputs of the tile read inputs at w, w + stride, w + 2*stride,
  // w + 3*stride; one accumulator per output position.
148+ int packed_input0 = load_input(
149+ w,
150+ h,
151+ out_block_idx.data.z,
152+ out_block_extents.data.z,
153+ in_block_extents);
154+
155+ ivec4 input0 = unpack_int8x4(packed_input0);
156+ acc0 += weight_4c * input0;
157+
158+ int packed_input1 = load_input(
159+ w + conv2d_params.stride.x,
160+ h,
161+ out_block_idx.data.z,
162+ out_block_extents.data.z,
163+ in_block_extents);
164+
165+ ivec4 input1 = unpack_int8x4(packed_input1);
166+ acc1 += weight_4c * input1;
167+
168+ int packed_input2 = load_input(
169+ w + conv2d_params.stride.x * 2 ,
170+ h,
171+ out_block_idx.data.z,
172+ out_block_extents.data.z,
173+ in_block_extents);
174+
175+ ivec4 input2 = unpack_int8x4(packed_input2);
176+ acc2 += weight_4c * input2;
177+
178+ int packed_input3 = load_input(
179+ w + conv2d_params.stride.x * 3 ,
180+ h,
181+ out_block_idx.data.z,
182+ out_block_extents.data.z,
183+ in_block_extents);
184+
185+ ivec4 input3 = unpack_int8x4(packed_input3);
186+ acc3 += weight_4c * input3;
187+ }
104188 }
105189
190+ // Apply input zero point as weight_sum * input_zp.
  // Dequantize: float = (acc - weight_sum * input_zp) * weight_scale * input_scale.
191+ vec4 weight_sums = vec4 (t_weight_sums[out_block_idx.data.z]);
192+ const vec4 weight_scales = vec4 (t_weight_scales[out_block_idx.data.z]);
193+
194+ vec4 facc0 = vec4 (acc0);
195+ facc0 -= weight_sums * input_zp;
196+ facc0 *= weight_scales * input_scale;
197+
198+ vec4 facc1 = vec4 (acc1);
199+ facc1 -= weight_sums * input_zp;
200+ facc1 *= weight_scales * input_scale;
201+
202+ vec4 facc2 = vec4 (acc2);
203+ facc2 -= weight_sums * input_zp;
204+ facc2 *= weight_scales * input_scale;
205+
206+ vec4 facc3 = vec4 (acc3);
207+ facc3 -= weight_sums * input_zp;
208+ facc3 *= weight_scales * input_scale;
209+
106210 if (apply_bias > 0 ) {
107211 const vec4 bias = vec4 (t_bias[out_block_idx.data.z]);
108- for (int row = 0 ; row < 4 ; row++ ) {
109- out_block.data[row] += bias;
110- }
212+ facc0 += bias;
213+ facc1 += bias;
214+ facc2 += bias;
215+ facc3 += bias;
111216 }
112217
113- const ivec4 packed_out_block = quantize_and_pack(
114- out_block, output_inv_scale, output_zp);
  // Requantize to the int8 range [-128, 127] and pack 4 values per int32 word.
218+ const ivec4 quantized_out0 = clamp (ivec4 (round(facc0 * output_inv_scale) + output_zp), - 128 , 127 );
219+ const ivec4 quantized_out1 = clamp (ivec4 (round(facc1 * output_inv_scale) + output_zp), - 128 , 127 );
220+ const ivec4 quantized_out2 = clamp (ivec4 (round(facc2 * output_inv_scale) + output_zp), - 128 , 127 );
221+ const ivec4 quantized_out3 = clamp (ivec4 (round(facc3 * output_inv_scale) + output_zp), - 128 , 127 );
222+
223+ const int packed_out_subblock0 = pack_into_int32(quantized_out0);
224+ const int packed_out_subblock1 = pack_into_int32(quantized_out1);
225+ const int packed_out_subblock2 = pack_into_int32(quantized_out2);
226+ const int packed_out_subblock3 = pack_into_int32(quantized_out3);
115227
116228#ifdef PACKED_INT8_OUTPUT_BUFFER
  // NOTE(review): store index depends only on tid, not subtile (see note
  // at out_w above) -- verify the intended subtile write layout.
117- t_packed_int8_output[tid] = packed_out_block;
229+ t_packed_int8_output[tid * 4 ] = packed_out_subblock0;
230+ t_packed_int8_output[tid * 4 + 1 ] = packed_out_subblock1;
231+ t_packed_int8_output[tid * 4 + 2 ] = packed_out_subblock2;
232+ t_packed_int8_output[tid * 4 + 3 ] = packed_out_subblock3;
118233#else
119- imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block);
234+ // imageStore(t_packed_int8_output, out_block_idx.data, packed_out_block);
  // Texture-backed output storage is not implemented yet.
235+ #error Unimplemented
120236#endif
121237}