@@ -85,8 +85,6 @@ void add_dequantize_per_tensor_node(
8585 add_dtype_suffix (kernel_name, graph.dtype_of (input));
8686 add_dtype_suffix (kernel_name, graph.dtype_of (output));
8787
88- float scale_val = static_cast <float >(graph.get_double (scale));
89- int zero_point_val = static_cast <int >(graph.get_int (zero_point));
9088 int quant_min_val = static_cast <int >(graph.get_int (quant_min));
9189 int quant_max_val = static_cast <int >(graph.get_int (quant_max));
9290
@@ -100,23 +98,16 @@ void add_dequantize_per_tensor_node(
10098 graph.strides_ubo (input),
10199 graph.sizes_ubo (output),
102100 graph.strides_ubo (output)};
103- push_constants = {
104- PushConstantDataInfo (&scale_val, sizeof (float )),
105- PushConstantDataInfo (&zero_point_val, sizeof (int )),
106- PushConstantDataInfo (&quant_min_val, sizeof (int )),
107- PushConstantDataInfo (&quant_max_val, sizeof (int )),
108- };
109101 } else {
110102 param_ubos = {
111103 graph.logical_limits_ubo (input), graph.logical_limits_ubo (output)};
112- push_constants = {
113- PushConstantDataInfo (&scale_val, sizeof (float )),
114- PushConstantDataInfo (&zero_point_val, sizeof (int )),
115- PushConstantDataInfo (&quant_min_val, sizeof (int )),
116- PushConstantDataInfo (&quant_max_val, sizeof (int )),
117- };
118104 }
119105
106+ push_constants = {
107+ PushConstantDataInfo (&quant_min_val, sizeof (int )),
108+ PushConstantDataInfo (&quant_max_val, sizeof (int )),
109+ };
110+
120111 vkapi::SpecVarList spec_vars = {
121112 graph.hashed_layout_of (output),
122113 graph.hashed_layout_of (input),
@@ -128,7 +119,9 @@ void add_dequantize_per_tensor_node(
128119 default_pick_global_wg_size,
129120 default_pick_local_wg_size,
130121 // Inputs and Outputs
131- {{output, vkapi::kWrite }, {input, vkapi::kRead }},
122+ {{output, vkapi::kWrite },
123+ {input, vkapi::kRead },
124+ {{scale, zero_point}, vkapi::kRead }},
132125 // Shader param buffers
133126 param_ubos,
134127 // Push Constants
@@ -517,7 +510,7 @@ void dequantize_per_channel_impl(
517510
518511REGISTER_OPERATORS {
519512 VK_REGISTER_OP (
520- quantized_decomposed.dequantize_per_tensor .default ,
513+ quantized_decomposed.dequantize_per_tensor .tensor ,
521514 dequantize_per_tensor_impl);
522515 VK_REGISTER_OP (
523516 quantized_decomposed.dequantize_per_token .default ,
0 commit comments