1111
1212#extension GL_EXT_control_flow_attributes : require
1313
14- struct InputWindow1D {
15- vec4[MAX_WINDOW_WIDTH] data;
16- int len;
17- };
18-
19- InputWindow1D initial_input_window() {
20- InputWindow1D input_window;
21- for (int i = 0; i < MAX_WINDOW_WIDTH; ++i) {
22- input_window.data[i] = vec4(0);
23- }
24- input_window.len = 0;
25- return input_window;
26- }
27-
2814vec4 dequantize(const int packed_texel, const float scale, const int zp) {
2915 return vec4(unpack_int8x4(packed_texel) - zp) * scale;
3016}
@@ -49,109 +35,10 @@ bool in_bounds(
4935 return true;
5036}
5137
52- InputWindow1D load_input_window(
53- const int w_start,
54- const int w_end,
55- const int h,
56- const int c4,
57- const Conv2dBlockExtents block_extents,
58- const float input_scale,
59- const int input_zp,
60- const ivec4 input_zps) {
61- InputWindow1D input_window = initial_input_window();
62-
63- const int block_w_start = div_4(w_start);
64- const int block_w_end = div_4(w_end);
65-
66- int window_i = 0;
67- for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) {
68- ivec4 input_block = input_zps;
69-
70- if (in_bounds(block_w, h, c4, block_extents)) {
71- #ifdef PACKED_INT8_INPUT_BUFFER
72- const int buffer_idx =
73- h * block_extents.data_xz + block_w * block_extents.data.z + c4;
74- input_block = t_packed_int8_input[buffer_idx];
75- #else
76- input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0);
77- #endif
78- }
79-
80- const int loaded_w_start = mul_4(block_w);
81- for (int row = 0; row < 4; ++row) {
82- if (loaded_w_start + row >= w_start && loaded_w_start + row <= w_end) {
83- input_window.data[window_i++] =
84- dequantize(input_block[row], input_scale, input_zp);
85- }
86- }
87- }
88- input_window.len = window_i;
89- return input_window;
90- }
91-
92- struct WeightRow {
93- vec4[MAX_KERNEL_WIDTH] data;
94- int len;
95- };
96-
97- WeightRow initial_weight_row() {
98- WeightRow weight_row;
99- for (int i = 0; i < MAX_KERNEL_WIDTH; ++i) {
100- weight_row.data[i] = vec4(0);
101- }
102- weight_row.len = 0;
103- return weight_row;
104- }
105-
106- WeightRow load_weight_row(
107- const int oc4,
108- const int ky,
109- const int OC4,
110- const int Kw,
111- const int Kw4,
112- const vec4 weight_scales) {
113- WeightRow weight_row = initial_weight_row();
114-
115- int k4 = ky * Kw4;
116- int row_idx = 0;
117- for (int w = 0; w < Kw; w += 4) {
118- #ifdef WEIGHT_BUFFER
119- const ivec4 weight_block = t_packed_int8_weight[k4 * OC4 + oc4];
120- #else
121- const ivec4 weight_block = texelFetch(
122- t_packed_int8_weight, ivec2(oc4, k4), 0);
123- #endif
124-
125- for (int row = 0; row < 4; ++row) {
126- if (w + row < Kw) {
127- weight_row.data[row_idx++] = dequantize(weight_block[row], weight_scales);
128- }
129- }
130- k4++;
131- }
132- weight_row.len = row_idx;
133- return weight_row;
134- }
135-
13638struct FPOutBlock {
13739 vec4[4] data;
13840};
13941
140- void perform_conv1d(
141- inout FPOutBlock out_block,
142- const InputWindow1D input_window,
143- const WeightRow weight_row) {
144- for (int out_w = 0; out_w < 4; ++out_w) {
145- [[unroll]] for (int kx = 0; kx < weight_row.len; ++kx) {
146- const int in_w = out_w * conv2d_params.stride.x;
147- out_block.data[out_w] = fma(
148- input_window.data[in_w + kx],
149- weight_row.data[kx],
150- out_block.data[out_w]);
151- }
152- }
153- }
154-
15542ivec4 quantize(
15643 const vec4 texel, const float inv_scale, const int zp) {
15744 vec4 quantized = round(texel * inv_scale) + zp;
@@ -168,6 +55,50 @@ ivec4 quantize_and_pack(
16855 return packed_block;
16956}
17057
58+ // Load a 4xint8 block of weights. Equivalent to unpacked_weights[kh][kw][c:c+4].
59+ int load_weight_1w4c(
60+ int kw, // w coordinate
61+ int kh, // h coordinate
62+ int oc4, // channel block
63+ int KW4, // kernel width / 4 (rounded up)
64+ int OC4 // out channels count / 4 (rounded up)
65+ ) {
66+
67+ // Find the packed block index. Weights are packed as 4W4C tiles.
68+ int kw4 = kw / 4; // W block
69+ int linear_idx = ((kh * KW4 + kw4) * OC4 + oc4) * 4;
70+ int block_x_offset = kw % 4;
71+ #ifdef WEIGHT_BUFFER
72+ return t_packed_int8_weight[linear_idx + block_x_offset];
73+ #else
74+ return texelFetch(t_packed_int8_weight, ivec2(oc4, kh * KW4 + kw4), 0)[block_x_offset];
75+ #endif
76+ }
77+
78+ // Load a 4xint8 block of inputs - channel c through c+3 (c = oc4*4) at
79+ // the given spatial location. Equivalent to unpacked_input[0][c:c+4][h][w].
80+ int load_input_1w4c(
81+ int w, // w coordinate
82+ int h, // h coordinate
83+ int oc4, // channel block
84+ int OC4, // out channels / 4 (rounded up)
85+ Conv2dBlockExtents block_extents
86+ ) {
87+ int block_w = w / 4;
88+
89+ if (in_bounds(block_w, h, oc4, block_extents) && w >= 0) {
90+ #ifdef PACKED_INT8_INPUT_BUFFER
91+ const int buffer_idx =
92+ (h * block_extents.data_xz + block_w * block_extents.data.z + oc4) * 4 + (w % 4);
93+ return t_packed_int8_input[buffer_idx];
94+ #else
95+ #error Unimplemented
96+ #endif
97+ } else {
98+ return pack_into_int32(ivec4(input_zp));
99+ }
100+ }
101+
171102#ifdef DEBUG_MODE
172103
173104void printInputWindow1D(const InputWindow1D input_window) {
0 commit comments