@@ -41,7 +41,7 @@ std::tuple<int64_t, int64_t> get_quantized_input_num_blocks(
4141 return std::make_tuple (num_blocks_M, num_blocks_K);
4242}
4343
44- utils::uvec3 quant_pack_input_global_wg_size (
44+ utils::uvec3 quantize_and_pack_4h4w_global_wg_size (
4545 ComputeGraph* graph,
4646 const vkapi::ShaderInfo& shader,
4747 const std::vector<ArgGroup>& args,
@@ -57,7 +57,7 @@ utils::uvec3 quant_pack_input_global_wg_size(
5757 1u };
5858}
5959
60- vkapi::ShaderInfo pick_quantize_and_pack_input_with_sums_shader (
60+ vkapi::ShaderInfo pick_quantize_and_pack_4h4w_with_group_sums_shader (
6161 ComputeGraph* graph,
6262 const std::vector<ArgGroup>& args,
6363 const std::vector<ValueRef>& resize_args) {
@@ -67,7 +67,7 @@ vkapi::ShaderInfo pick_quantize_and_pack_input_with_sums_shader(
6767
6868 const int64_t group_size_val = graph->extract_scalar <int64_t >(group_size);
6969
70- std::string shader_name = " quantize_and_pack_linear_input_with_sums " ;
70+ std::string shader_name = " quantize_and_pack_4h4w_with_group_sums " ;
7171 if (group_size_val >= 128 ) {
7272 shader_name += " _o2w32" ;
7373 } else {
@@ -82,7 +82,7 @@ vkapi::ShaderInfo pick_quantize_and_pack_input_with_sums_shader(
8282 return VK_KERNEL_FROM_STR (shader_name);
8383}
8484
85- utils::uvec3 pick_quantize_and_pack_input_with_sums_global_wg_size (
85+ utils::uvec3 pick_quantize_and_pack_4h4w_with_group_sums_global_wg_size (
8686 ComputeGraph* graph,
8787 const vkapi::ShaderInfo& shader,
8888 const std::vector<ArgGroup>& args,
@@ -113,7 +113,7 @@ utils::uvec3 pick_quantize_and_pack_input_with_sums_global_wg_size(
113113 1u };
114114}
115115
116- utils::uvec3 pick_quantize_and_pack_input_with_sums_local_wg_size (
116+ utils::uvec3 pick_quantize_and_pack_4h4w_with_group_sums_local_wg_size (
117117 ComputeGraph* graph,
118118 const vkapi::ShaderInfo& shader,
119119 const utils::uvec3& global_workgroup_size,
@@ -144,7 +144,7 @@ utils::uvec3 pick_quantize_and_pack_input_with_sums_local_wg_size(
144144// Dispatch logic (Linear)
145145//
146146
147- void add_quantize_and_pack_linear_input_node (
147+ void add_quantize_and_pack_4h4w_node (
148148 ComputeGraph& graph,
149149 const QuantizationConfig& input_quant_config,
150150 const ValueRef fp_input,
@@ -164,7 +164,7 @@ void add_quantize_and_pack_linear_input_node(
164164 float inv_scale = 1 .0f / graph.extract_scalar <float >(input_scale_data);
165165 int32_t zp = graph.extract_scalar <int32_t >(input_zp_data);
166166
167- std::string shader_name = " quantize_and_pack_linear_input_per_tensor " ;
167+ std::string shader_name = " quantize_and_pack_4h4w_per_tensor " ;
168168 add_storage_type_suffix (shader_name, graph.storage_type_of (packed_int_input));
169169 add_storage_type_suffix (shader_name, graph.storage_type_of (fp_input));
170170 add_dtype_suffix (shader_name, graph.dtype_of (fp_input));
@@ -179,7 +179,7 @@ void add_quantize_and_pack_linear_input_node(
179179 graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
180180 graph,
181181 VK_KERNEL_FROM_STR (shader_name),
182- quant_pack_input_global_wg_size ,
182+ quantize_and_pack_4h4w_global_wg_size ,
183183 default_pick_local_wg_size,
184184 // Inputs and Outputs
185185 {{packed_int_input, vkapi::kWrite }, {fp_input, vkapi::kRead }},
@@ -193,7 +193,7 @@ void add_quantize_and_pack_linear_input_node(
193193 {}));
194194}
195195
196- void add_quantize_and_pack_linear_input_with_sums_node (
196+ void add_quantize_and_pack_4h4w_with_group_sums_node (
197197 ComputeGraph& graph,
198198 const QuantizationConfig& input_quant_config,
199199 const ValueRef fp_input,
@@ -216,9 +216,9 @@ void add_quantize_and_pack_linear_input_with_sums_node(
216216
217217 graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
218218 graph,
219- pick_quantize_and_pack_input_with_sums_shader ,
220- pick_quantize_and_pack_input_with_sums_global_wg_size ,
221- pick_quantize_and_pack_input_with_sums_local_wg_size ,
219+ pick_quantize_and_pack_4h4w_with_group_sums_shader ,
220+ pick_quantize_and_pack_4h4w_with_group_sums_global_wg_size ,
221+ pick_quantize_and_pack_4h4w_with_group_sums_local_wg_size ,
222222 // Inputs and Outputs
223223 {{{packed_int_input, int_input_sums}, vkapi::kWrite },
224224 {{fp_input, packed_input_scales, packed_input_zps}, vkapi::kRead }},
@@ -236,7 +236,7 @@ void add_quantize_and_pack_linear_input_with_sums_node(
236236// Dispatch utilities (Conv2d)
237237//
238238
239- utils::uvec3 pick_quantize_and_pack_conv2d_input_global_wg_size (
239+ utils::uvec3 pick_quantize_and_pack_4w4c_global_wg_size (
240240 ComputeGraph* graph,
241241 const vkapi::ShaderInfo& shader,
242242 const std::vector<ArgGroup>& args,
@@ -253,7 +253,7 @@ utils::uvec3 pick_quantize_and_pack_conv2d_input_global_wg_size(
253253 return {W4, H, C4};
254254}
255255
256- utils::uvec3 pick_unpack_and_dequantize_conv2d_output_global_wg_size (
256+ utils::uvec3 pick_unpack_4w4c_and_dequantize_global_wg_size (
257257 ComputeGraph* graph,
258258 const vkapi::ShaderInfo& shader,
259259 const std::vector<ArgGroup>& args,
@@ -274,7 +274,7 @@ utils::uvec3 pick_unpack_and_dequantize_conv2d_output_global_wg_size(
274274// Dispatch logic (Conv2d)
275275//
276276
277- void add_quantize_and_pack_q8ta_conv2d_input_node (
277+ void add_quantize_and_pack_4w4c_node (
278278 ComputeGraph& graph,
279279 const ValueRef fp_input,
280280 const ValueRef input_scale,
@@ -284,7 +284,7 @@ void add_quantize_and_pack_q8ta_conv2d_input_node(
284284 int32_t zp = graph.extract_scalar <int32_t >(input_zp);
285285
286286 // Build the shader name for the per-tensor 4W4C quantize-and-pack kernel
287- std::string kernel_name = " quantize_and_pack_q8ta_conv2d_input " ;
287+ std::string kernel_name = " quantize_and_pack_4w4c_per_tensor " ;
288288 add_storage_type_suffix (
289289 kernel_name, graph.storage_type_of (packed_int8_input));
290290 add_storage_type_suffix (kernel_name, graph.storage_type_of (fp_input));
@@ -302,7 +302,7 @@ void add_quantize_and_pack_q8ta_conv2d_input_node(
302302 graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
303303 graph,
304304 VK_KERNEL_FROM_STR (kernel_name),
305- pick_quantize_and_pack_conv2d_input_global_wg_size ,
305+ pick_quantize_and_pack_4w4c_global_wg_size ,
306306 pick_wc_square_wg_size,
307307 // Inputs and Outputs
308308 {{packed_int8_input, vkapi::kWrite }, {fp_input, vkapi::kRead }},
@@ -318,7 +318,7 @@ void add_quantize_and_pack_q8ta_conv2d_input_node(
318318 nullptr ));
319319}
320320
321- void add_unpack_and_dequantize_q8ta_conv2d_output_node (
321+ void add_unpack_4w4c_and_dequantize_node (
322322 ComputeGraph& graph,
323323 const ValueRef packed_int8_output,
324324 const ValueRef output_scale,
@@ -328,7 +328,7 @@ void add_unpack_and_dequantize_q8ta_conv2d_output_node(
328328 int32_t zp = graph.extract_scalar <int32_t >(output_zp);
329329
330330 // Build the shader name for the per-tensor 4W4C unpack-and-dequantize kernel
331- std::string kernel_name = " unpack_and_dequantize_q8ta_conv2d_output " ;
331+ std::string kernel_name = " unpack_4w4c_and_dequantize_per_tensor " ;
332332 add_storage_type_suffix (kernel_name, graph.storage_type_of (fp_output));
333333 add_storage_type_suffix (
334334 kernel_name, graph.storage_type_of (packed_int8_output));
@@ -346,7 +346,7 @@ void add_unpack_and_dequantize_q8ta_conv2d_output_node(
346346 graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
347347 graph,
348348 VK_KERNEL_FROM_STR (kernel_name),
349- pick_unpack_and_dequantize_conv2d_output_global_wg_size ,
349+ pick_unpack_4w4c_and_dequantize_global_wg_size ,
350350 default_pick_local_wg_size,
351351 // Inputs and Outputs
352352 {{fp_output, vkapi::kWrite }, {packed_int8_output, vkapi::kRead }},
@@ -375,7 +375,7 @@ void quantize_q8ta_for_conv2d(
375375 const ValueRef zero_point = args.at (idx++);
376376 const ValueRef packed_int8_input = args.at (idx++);
377377
378- add_quantize_and_pack_q8ta_conv2d_input_node (
378+ add_quantize_and_pack_4w4c_node (
379379 graph, fp_input, scale, zero_point, packed_int8_input);
380380}
381381
@@ -388,7 +388,7 @@ void dequantize_q8to_from_conv2d(
388388 const ValueRef zero_point = args.at (idx++);
389389 const ValueRef fp_output = args.at (idx++);
390390
391- add_unpack_and_dequantize_q8ta_conv2d_output_node (
391+ add_unpack_4w4c_and_dequantize_node (
392392 graph, packed_int8_output, scale, zero_point, fp_output);
393393}
394394
@@ -408,10 +408,10 @@ void qdq8ta_conv2d_input(
408408 utils::kBuffer ,
409409 utils::kPackedInt8_4W4C );
410410
411- add_quantize_and_pack_q8ta_conv2d_input_node (
411+ add_quantize_and_pack_4w4c_node (
412412 graph, fp_input, scale, zero_point, packed_int8_input);
413413
414- add_unpack_and_dequantize_q8ta_conv2d_output_node (
414+ add_unpack_4w4c_and_dequantize_node (
415415 graph, packed_int8_input, scale, zero_point, fp_output);
416416}
417417
0 commit comments