@@ -87,8 +87,6 @@ void add_quantize_per_tensor_node(
8787 add_dtype_suffix (kernel_name, graph.dtype_of (input));
8888 add_dtype_suffix (kernel_name, graph.dtype_of (output));
8989
90- float scale_val = static_cast <float >(graph.get_double (scale));
91- int zero_point_val = static_cast <int >(graph.get_int (zero_point));
9290 int quant_min_val = static_cast <int >(graph.get_int (quant_min));
9391 int quant_max_val = static_cast <int >(graph.get_int (quant_max));
9492
@@ -102,23 +100,16 @@ void add_quantize_per_tensor_node(
102100 graph.strides_ubo (input),
103101 graph.sizes_ubo (output),
104102 graph.strides_ubo (output)};
105- push_constants = {
106- PushConstantDataInfo (&scale_val, sizeof (float )),
107- PushConstantDataInfo (&zero_point_val, sizeof (int )),
108- PushConstantDataInfo (&quant_min_val, sizeof (int )),
109- PushConstantDataInfo (&quant_max_val, sizeof (int )),
110- };
111103 } else {
112104 param_ubos = {
113105 graph.logical_limits_ubo (input), graph.logical_limits_ubo (output)};
114- push_constants = {
115- PushConstantDataInfo (&scale_val, sizeof (float )),
116- PushConstantDataInfo (&zero_point_val, sizeof (int )),
117- PushConstantDataInfo (&quant_min_val, sizeof (int )),
118- PushConstantDataInfo (&quant_max_val, sizeof (int )),
119- };
120106 }
121107
108+ push_constants = {
109+ PushConstantDataInfo (&quant_min_val, sizeof (int )),
110+ PushConstantDataInfo (&quant_max_val, sizeof (int )),
111+ };
112+
122113 vkapi::SpecVarList spec_vars = {
123114 graph.hashed_layout_of (output),
124115 graph.hashed_layout_of (input),
@@ -130,7 +121,9 @@ void add_quantize_per_tensor_node(
130121 default_pick_global_wg_size,
131122 default_pick_local_wg_size,
132123 // Inputs and Outputs
133- {{output, vkapi::kWrite }, {input, vkapi::kRead }},
124+ {{output, vkapi::kWrite },
125+ {input, vkapi::kRead },
126+ {{scale, zero_point}, vkapi::kRead }},
134127 // Shader param buffers
135128 param_ubos,
136129 // Push Constants
@@ -489,7 +482,7 @@ void quantize_per_channel_impl(
489482
490483REGISTER_OPERATORS {
491484 VK_REGISTER_OP (
492- quantized_decomposed.quantize_per_tensor .default ,
485+ quantized_decomposed.quantize_per_tensor .tensor ,
493486 quantize_per_tensor_impl);
494487 VK_REGISTER_OP (
495488 quantized_decomposed.quantize_per_token .default , quantize_per_token_impl);
0 commit comments