1313#define IN_T ${buffer_scalar_type(IN_DTYPE)}
1414#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
1515
16+ #define ${MODE}
17+
1618${define_active_storage_type("buffer ")}
1719${define_required_extensions(IN_DTYPE)}
1820${define_required_extensions(OUT_DTYPE)}
1921
2022layout (std430) buffer ;
2123
24+ #include "indexing_utils.h"
25+
2226${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer ")}
2327${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer ")}
2428
@@ -29,7 +33,7 @@ $if MODE == "per_tensor":
2933 int quant_min;
3034 int quant_max;
3135 };
32- $else :
36+ $if MODE == "per_token" :
3337 ${layout_declare_tensor(B, "r", "t_scale", "float ", "buffer ")}
3438 ${layout_declare_tensor(B, "r", "t_zero_point", "int ", "buffer ")}
3539
@@ -39,87 +43,77 @@ $else:
3943 int quant_max;
4044 };
4145
46+ ${layout_declare_ubo(B, "int ", "out_numel")}
4247${layout_declare_ubo(B, "ivec4 ", "t_in_sizes")}
4348${layout_declare_ubo(B, "ivec4 ", "t_in_strides")}
4449${layout_declare_ubo(B, "ivec4 ", "t_out_sizes")}
4550${layout_declare_ubo(B, "ivec4 ", "t_out_strides")}
4651
47- #include "indexing_utils.h"
52+ ${layout_declare_spec_const(C, "int ", "out_layout", "DEFAULT_LAYOUT")}
53+ ${layout_declare_spec_const(C, "int ", "in_layout", "DEFAULT_LAYOUT")}
54+
4855#include "quantize.glslh"
4956
5057layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
5158
52- void main() {
53- $if MODE == "per_tensor":
54- const ivec4 pos = ivec4 (
55- gl_GlobalInvocationID.x,
56- gl_GlobalInvocationID.y,
57- gl_GlobalInvocationID.z,
58- 0 );
59-
60- const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
61- const int t_out_idx = tidx_to_bufi(pos, t_out_strides);
62-
63- IN_T value = t_in[t_in_idx];
64- OUT_T qvalue;
65-
66- qvalue = quantize_val(value, scale, zero_point);
59+ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
60+ const lowp ivec4 in_dim_order = unhash_dim_order(in_layout);
6761
68- t_out[t_out_idx] = qvalue;
62+ #ifdef per_tensor
6963
70- $if MODE == "per_token":
71- const ivec4 pos = ivec4 (
72- gl_GlobalInvocationID.x,
73- gl_GlobalInvocationID.y,
74- gl_GlobalInvocationID.z,
75- 0 );
76-
77- const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
78- const int t_out_idx = tidx_to_bufi(pos, t_out_strides);
64+ void quantize_per_tensor() {
65+ const int out_bufi = int (gl_GlobalInvocationID.x);
7966
80- // Skip if out of bounds
81- if (t_in_idx >= t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w) {
67+ if (out_bufi >= out_numel) {
8268 return ;
8369 }
8470
85- IN_T value = t_in[t_in_idx];
86- OUT_T qvalue;
71+ const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
72+ const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
73+
74+ IN_T value = t_in[in_bufi];
75+ OUT_T qvalue = quantize_val(value, scale, zero_point);
8776
88- // Calculate logical position from linear index and strides
89- ivec4 logical_pos;
90- int remaining = t_in_idx;
77+ t_out[out_bufi] = qvalue;
78+ }
9179
92- logical_pos.x = remaining % t_in_sizes.x;
93- remaining /= t_in_sizes.x;
80+ #else
9481
95- logical_pos.y = remaining % t_in_sizes.y;
96- remaining /= t_in_sizes.y ;
82+ void quantize_per_token() {
83+ const int out_bufi = int (gl_GlobalInvocationID.x) ;
9784
98- logical_pos.z = remaining % t_in_sizes.z;
99- remaining /= t_in_sizes.z;
85+ if (out_bufi >= out_numel) {
86+ return ;
87+ }
10088
101- logical_pos.w = remaining;
89+ const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
90+ const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
91+
92+ IN_T value = t_in[in_bufi];
10293
103- // Calculate token index based on logical position
10494 int token_idx = 0 ;
10595
106- // Check dimensions to determine how to calculate token_idx
107- if (t_in_sizes.w > 1 ) {
96+ if (t_out_sizes.w > 1 ) {
10897 // 4D tensor
109- token_idx = logical_pos .w * (t_in_sizes .z * t_in_sizes .y) + logical_pos .z * t_in_sizes .y + logical_pos .y;
110- } else if (t_in_sizes .z > 1 ) {
98+ token_idx = out_tidx .w * (t_out_sizes .z * t_out_sizes .y) + out_tidx .z * t_out_sizes .y + out_tidx .y;
99+ } else if (t_out_sizes .z > 1 ) {
111100 // 3D tensor
112- token_idx = logical_pos .z * t_in_sizes .y + logical_pos .y;
113- } else if (t_in_sizes .y > 1 ) {
101+ token_idx = out_tidx .z * t_out_sizes .y + out_tidx .y;
102+ } else if (t_out_sizes .y > 1 ) {
114103 // 2D tensor
115- token_idx = logical_pos .y;
104+ token_idx = out_tidx .y;
116105 }
117106 // For 1D tensor, token_idx remains 0
118107
119- // Make sure token_idx is within bounds
120108 token_idx = min (token_idx, num_tokens - 1 );
121109
122- qvalue = quantize_val(value, t_scale[token_idx], t_zero_point[token_idx]);
110+ OUT_T qvalue = quantize_val(value, t_scale[token_idx], t_zero_point[token_idx]);
111+
112+ t_out[out_bufi] = qvalue;
113+ }
123114
124- t_out[t_out_idx] = qvalue;
115+ #endif
116+
117+ void main() {
118+ quantize_${MODE}();
125119}
0 commit comments