@@ -51,17 +51,19 @@ utils::uvec3 quantize_per_channel_local_wg_size(
5151
5252 const ValueRef input = args.at (1 ).refs .at (0 );
5353
54- utils::uvec3 local_wg_size = graph->create_local_wg_size (global_workgroup_size);
55-
56- // WORKAROUND: The CommandBuffer::dispatch function divides global_workgroup_size
57- // by local_workgroup_size to get the number of workgroups to dispatch.
58- // For per-channel quantization along the batch axis, we need to ensure that
59- // we dispatch the correct number of workgroups in the Z dimension to cover
60- // all batch-channel combinations.
54+ utils::uvec3 local_wg_size =
55+ graph->create_local_wg_size (global_workgroup_size);
56+
57+ // WORKAROUND: The CommandBuffer::dispatch function divides
58+ // global_workgroup_size by local_workgroup_size to get the number of
59+ // workgroups to dispatch. For per-channel quantization along the batch axis,
60+ // we need to ensure that we dispatch the correct number of workgroups in the
61+ // Z dimension to cover all batch-channel combinations.
6162 //
62- // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], local_wg_size[2])
63- // might reduce the number of workgroups dispatched. To ensure we dispatch
64- // global_workgroup_size[2] workgroups in the Z dimension, we set local_wg_size[2] = 1.
63+ // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2],
64+ // local_wg_size[2]) might reduce the number of workgroups dispatched. To
65+ // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension,
66+ // we set local_wg_size[2] = 1.
6567 const auto input_sizes = graph->sizes_of (input);
6668 if (global_workgroup_size[2 ] > 1 && input_sizes[3 ] > 0 ) {
6769 local_wg_size[2 ] = 1 ;
@@ -78,13 +80,23 @@ void add_quantize_per_tensor_node(
7880 const ValueRef& quant_min,
7981 const ValueRef& quant_max,
8082 const ValueRef& output) {
83+ const bool is_tensor_scale_zp =
84+ graph.val_is_tensor (scale) && graph.val_is_tensor (zero_point);
85+
8186 std::string kernel_name (" quantize_per_tensor" );
87+ if (is_tensor_scale_zp) {
88+ kernel_name += " _tensor" ;
89+ }
8290 add_storage_type_suffix (kernel_name, graph.storage_type_of (input));
8391 add_dtype_suffix (kernel_name, graph.dtype_of (input));
8492 add_dtype_suffix (kernel_name, graph.dtype_of (output));
8593
86- float scale_val = static_cast <float >(graph.get_double (scale));
87- int zero_point_val = static_cast <int >(graph.get_int (zero_point));
94+ float scale_val = 1.0 ;
95+ int zero_point_val = 0 ;
96+ if (!is_tensor_scale_zp) {
97+ scale_val = static_cast <float >(graph.get_double (scale));
98+ zero_point_val = static_cast <int >(graph.get_int (zero_point));
99+ }
88100 int quant_min_val = static_cast <int >(graph.get_int (quant_min));
89101 int quant_max_val = static_cast <int >(graph.get_int (quant_max));
90102
@@ -98,15 +110,17 @@ void add_quantize_per_tensor_node(
98110 graph.strides_ubo (input),
99111 graph.sizes_ubo (output),
100112 graph.strides_ubo (output)};
113+ } else {
114+ param_ubos = {
115+ graph.logical_limits_ubo (input), graph.logical_limits_ubo (output)};
116+ }
117+
118+ if (is_tensor_scale_zp) {
101119 push_constants = {
102- PushConstantDataInfo (&scale_val, sizeof (float )),
103- PushConstantDataInfo (&zero_point_val, sizeof (int )),
104120 PushConstantDataInfo (&quant_min_val, sizeof (int )),
105121 PushConstantDataInfo (&quant_max_val, sizeof (int )),
106122 };
107123 } else {
108- param_ubos = {
109- graph.logical_limits_ubo (input), graph.logical_limits_ubo (output)};
110124 push_constants = {
111125 PushConstantDataInfo (&scale_val, sizeof (float )),
112126 PushConstantDataInfo (&zero_point_val, sizeof (int )),
@@ -120,13 +134,20 @@ void add_quantize_per_tensor_node(
120134 graph.hashed_layout_of (input),
121135 };
122136
137+ std::vector<ArgGroup> inputs_and_outputs = {
138+ {output, vkapi::kWrite }, {input, vkapi::kRead }};
139+ if (is_tensor_scale_zp) {
140+ inputs_and_outputs.emplace_back (
141+ ArgGroup{{scale, zero_point}, vkapi::kRead });
142+ }
143+
123144 graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
124145 graph,
125146 VK_KERNEL_FROM_STR (kernel_name),
126147 default_pick_global_wg_size,
127148 default_pick_local_wg_size,
128149 // Inputs and Outputs
129- {{output, vkapi:: kWrite }, {input, vkapi:: kRead }} ,
150+ inputs_and_outputs ,
130151 // Shader param buffers
131152 param_ubos,
132153 // Push Constants
@@ -241,8 +262,8 @@ void add_quantize_per_channel_node(
241262
242263 int num_channels;
243264 if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage (input)) {
244- // For batch dimension quantization in 4D tensors, pass the actual number of channels
245- // so the shader can correctly unfold the batch-channel folding
265+ // For batch dimension quantization in 4D tensors, pass the actual number of
266+ // channels so the shader can correctly unfold the batch-channel folding
246267 num_channels = static_cast <int >(input_sizes[1 ]); // Channel dimension
247268 } else {
248269 num_channels = static_cast <int >(input_sizes[axis_val]);
@@ -487,6 +508,9 @@ REGISTER_OPERATORS {
487508 VK_REGISTER_OP (
488509 quantized_decomposed.quantize_per_tensor .default ,
489510 quantize_per_tensor_impl);
511+ VK_REGISTER_OP (
512+ quantized_decomposed.quantize_per_tensor .tensor ,
513+ quantize_per_tensor_impl);
490514 VK_REGISTER_OP (
491515 quantized_decomposed.quantize_per_token .default , quantize_per_token_impl);
492516 VK_REGISTER_OP (
0 commit comments