@@ -306,10 +306,12 @@ void choose_qparams_tensor_impl(
306306 graph.dtype_of (input) == vkapi::kHalf ||
307307 graph.dtype_of (input) == vkapi::kDouble );
308308
309- // Verify output types - only accept Vulkan-supported types
310- // The Vulkan backend only supports float32 and int32, not float64/int64
309+ // Verify output types - accept both int32 and float32 for zero_point
310+ // TorchAO may use float32 for zero_point in some cases
311311 VK_CHECK_COND (graph.dtype_of (scale_out) == vkapi::kFloat );
312- VK_CHECK_COND (graph.dtype_of (zero_point_out) == vkapi::kInt );
312+ VK_CHECK_COND (
313+ graph.dtype_of (zero_point_out) == vkapi::kInt ||
314+ graph.dtype_of (zero_point_out) == vkapi::kFloat );
313315
314316 // Check that texture storage is width packed
315317 if (!graph.is_buffer_storage (input)) {
@@ -352,21 +354,96 @@ void choose_qparams_per_token_asymmetric_impl(
352354 graph.dtype_of (input) == vkapi::kHalf ||
353355 graph.dtype_of (input) == vkapi::kDouble );
354356
355- // Verify output types - only accept Vulkan-supported types
356- // The Vulkan backend only supports float32 and int32, not float64/int64
357+ // Verify output types - accept both int32 and float32 for zero_point
358+ // TorchAO may use float32 for zero_point in some cases
357359 VK_CHECK_COND (graph.dtype_of (scale_out) == vkapi::kFloat );
358- VK_CHECK_COND (graph.dtype_of (zero_point_out) == vkapi::kInt );
360+ VK_CHECK_COND (
361+ graph.dtype_of (zero_point_out) == vkapi::kInt ||
362+ graph.dtype_of (zero_point_out) == vkapi::kFloat );
359363
360364 add_choose_qparams_per_token_asymmetric_node (
361365 graph, input, scale_out, zero_point_out);
362366}
363367
368+ void choose_qparams_affine_impl (
369+ ComputeGraph& graph,
370+ const std::vector<ValueRef>& args) {
371+ int arg_idx = 0 ;
372+ const ValueRef input = args[arg_idx++];
373+ const ValueRef mapping_type = args[arg_idx++]; // str - ignored for per-tensor
374+ const ValueRef block_size =
375+ args[arg_idx++]; // SymInt[] - ignored for per-tensor
376+ const ValueRef target_dtype = args[arg_idx++];
377+ const ValueRef quant_min = args[arg_idx++];
378+ const ValueRef quant_max = args[arg_idx++];
379+ const ValueRef eps = args[arg_idx++];
380+ const ValueRef scale_dtype = args[arg_idx++];
381+ const ValueRef zero_point_dtype = args[arg_idx++];
382+ const ValueRef out_tuple_ref = args[arg_idx++];
383+
384+ // Suppress unused variable warnings
385+ (void )mapping_type;
386+ (void )target_dtype;
387+ (void )scale_dtype;
388+ (void )zero_point_dtype;
389+
390+ ValueRef scale_out = kDummyValueRef ;
391+ ValueRef zero_point_out = kDummyValueRef ;
392+
393+ {
394+ const ValueListPtr out_tuple = graph.get_value_list (out_tuple_ref);
395+ scale_out = out_tuple->at (0 );
396+ zero_point_out = out_tuple->at (1 );
397+ }
398+
399+ // Check tensor types
400+ VK_CHECK_COND (graph.val_is_tensor (input));
401+ VK_CHECK_COND (graph.val_is_tensor (scale_out));
402+ VK_CHECK_COND (graph.val_is_tensor (zero_point_out));
403+
404+ // Verify input is a floating point type
405+ VK_CHECK_COND (
406+ graph.dtype_of (input) == vkapi::kFloat ||
407+ graph.dtype_of (input) == vkapi::kHalf ||
408+ graph.dtype_of (input) == vkapi::kDouble );
409+
410+ // Verify output types - accept both int32 and float32 for zero_point
411+ // TorchAO may use float32 for zero_point in some cases
412+ VK_CHECK_COND (graph.dtype_of (scale_out) == vkapi::kFloat );
413+ VK_CHECK_COND (
414+ graph.dtype_of (zero_point_out) == vkapi::kInt ||
415+ graph.dtype_of (zero_point_out) == vkapi::kFloat );
416+
417+ // Check if this is per-tensor quantization (only supported granularity)
418+ // block_size should equal input tensor dimensions for per-tensor quantization
419+ const auto input_sizes = graph.sizes_of (input);
420+ const auto block_size_list = graph.get_int_list (block_size);
421+ VK_CHECK_COND (block_size_list->size () == input_sizes.size ());
422+ for (size_t i = 0 ; i < input_sizes.size (); i++) {
423+ VK_CHECK_COND ((*block_size_list)[i] == input_sizes[i]);
424+ }
425+
426+ // Check that texture storage is width packed
427+ if (!graph.is_buffer_storage (input)) {
428+ VK_CHECK_COND (graph.packed_dim_of (input) == WHCN::kWidthDim );
429+ }
430+
431+ // Default to per-tensor quantization parameter calculation for TorchAO affine
432+ // ops
433+ add_choose_qparams_tensor_node (
434+ graph, input, quant_min, quant_max, eps, scale_out, zero_point_out);
435+ }
436+
364437REGISTER_OPERATORS {
365438 VK_REGISTER_OP (
366439 quantized_decomposed.choose_qparams .tensor , choose_qparams_tensor_impl);
367440 VK_REGISTER_OP (
368441 quantized_decomposed.choose_qparams_per_token_asymmetric .default ,
369442 choose_qparams_per_token_asymmetric_impl);
443+
444+ // TorchAO affine choose_qparams operators
445+ VK_REGISTER_OP (
446+ torchao.choose_qparams_affine .default , choose_qparams_affine_impl);
370447}
371448
372449} // namespace vkcompute
0 commit comments