@@ -311,6 +311,111 @@ std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl(
311311 return std::make_tuple (scale_out, zero_point_out);
312312}
313313
314+ /*
315+ * Reference implementation of choose_qparams_per_token_asymmetric
316+ */
317+ std::tuple<at::Tensor, at::Tensor>
318+ choose_qparams_per_token_asymmetric_reference_impl (
319+ const at::Tensor& input,
320+ at::ScalarType dtype) {
321+ // For per-token quantization, we need to compute scale and zero_point for
322+ // each token
323+ int64_t quant_min = -128 ;
324+ int64_t quant_max = 127 ;
325+
326+ // Calculate output sizes
327+ std::vector<int64_t > output_sizes;
328+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
329+ output_sizes.push_back (input.size (i));
330+ }
331+ output_sizes.push_back (1 );
332+
333+ // Create output tensors
334+ at::Tensor scale_out =
335+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kDouble ));
336+ at::Tensor zero_point_out =
337+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kLong ));
338+
339+ // Calculate number of tokens
340+ int64_t num_tokens = 1 ;
341+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
342+ num_tokens *= input.size (i);
343+ }
344+
345+ // Reshape input to [num_tokens, last_dim]
346+ at::Tensor reshaped_input = input.reshape ({num_tokens, input.size (-1 )});
347+
348+ // Process each token
349+ for (int64_t token_idx = 0 ; token_idx < num_tokens; token_idx++) {
350+ at::Tensor token = reshaped_input[token_idx];
351+
352+ // Find min and max values for this token
353+ float min_val = token.min ().item <float >();
354+ float max_val = token.max ().item <float >();
355+
356+ // Extend the [min, max] interval to ensure it contains 0
357+ min_val = std::min (min_val, 0 .f );
358+ max_val = std::max (max_val, 0 .f );
359+
360+ // Calculate scale
361+ double scale =
362+ (static_cast <double >(max_val) - min_val) / (quant_max - quant_min);
363+
364+ // Handle small scale
365+ constexpr float SMALL_SCALE_THRESHOLD = 6 .1e-5f ;
366+ if (float (scale) == 0 .0f || std::isinf (1 .0f / float (scale))) {
367+ scale = 0.1 ;
368+ }
369+
370+ if (scale < SMALL_SCALE_THRESHOLD) {
371+ float org_scale = scale;
372+ scale = SMALL_SCALE_THRESHOLD;
373+ // Adjust min and max based on new scale
374+ if (min_val == 0 .0f ) {
375+ max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
376+ } else if (max_val == 0 .0f ) {
377+ min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
378+ } else {
379+ float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
380+ min_val *= amplifier;
381+ max_val *= amplifier;
382+ }
383+ }
384+
385+ // Calculate zero point
386+ double zero_point_from_min =
387+ quant_min - min_val / static_cast <double >(scale);
388+ double zero_point_from_max =
389+ quant_max - max_val / static_cast <double >(scale);
390+ double zero_point_from_min_error =
391+ std::abs (quant_min) - std::abs (min_val / static_cast <double >(scale));
392+ double zero_point_from_max_error =
393+ std::abs (quant_max) - std::abs (max_val / static_cast <double >(scale));
394+ double initial_zero_point =
395+ zero_point_from_min_error < zero_point_from_max_error
396+ ? zero_point_from_min
397+ : zero_point_from_max;
398+
399+ // Nudge zero point to be an integer
400+ int64_t nudged_zero_point = 0 ;
401+ if (initial_zero_point < quant_min) {
402+ nudged_zero_point = quant_min;
403+ } else if (initial_zero_point > quant_max) {
404+ nudged_zero_point = quant_max;
405+ } else {
406+ nudged_zero_point =
407+ std::nearbyint (static_cast <float >(initial_zero_point));
408+ }
409+
410+ // Set output values for this token - use index_put_ for safety
411+ scale_out.view ({num_tokens, 1 }).index_put_ ({token_idx, 0 }, scale);
412+ zero_point_out.view ({num_tokens, 1 })
413+ .index_put_ ({token_idx, 0 }, nudged_zero_point);
414+ }
415+
416+ return std::make_tuple (scale_out, zero_point_out);
417+ }
418+
314419// Forward declaration of implementation functions
315420void test_vulkan_choose_qparams_tensor_impl (
316421 const std::vector<int >& input_sizes,
@@ -320,6 +425,12 @@ void test_vulkan_choose_qparams_tensor_impl(
320425 const vkcompute::utils::StorageType in_storage,
321426 const vkcompute::utils::StorageType out_storage);
322427
428+ void test_vulkan_choose_qparams_per_token_asymmetric_impl (
429+ const std::vector<int >& input_sizes,
430+ at::ScalarType dtype,
431+ const vkcompute::utils::StorageType in_storage,
432+ const vkcompute::utils::StorageType out_storage);
433+
323434// Wrapper function to test both buffer and texture storage types
324435void test_vulkan_choose_qparams_tensor (
325436 const std::vector<int >& input_sizes,
@@ -345,6 +456,22 @@ void test_vulkan_choose_qparams_tensor(
345456 vkcompute::utils::kTexture3D );
346457}
347458
459+ // Wrapper function to test both buffer and texture storage types
460+ void test_vulkan_choose_qparams_per_token_asymmetric (
461+ const std::vector<int >& input_sizes,
462+ at::ScalarType dtype) {
463+ // Test with buffer storage
464+ test_vulkan_choose_qparams_per_token_asymmetric_impl (
465+ input_sizes, dtype, vkcompute::utils::kBuffer , vkcompute::utils::kBuffer );
466+
467+ // Test with texture storage
468+ test_vulkan_choose_qparams_per_token_asymmetric_impl (
469+ input_sizes,
470+ dtype,
471+ vkcompute::utils::kTexture3D ,
472+ vkcompute::utils::kTexture3D );
473+ }
474+
348475void test_reference_choose_qparams_tensor (
349476 const std::vector<int >& input_sizes,
350477 int64_t quant_min,
@@ -506,3 +633,161 @@ TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) {
506633 127 , // quant_max
507634 at::kChar );
508635}
636+
637+ void test_reference_choose_qparams_per_token_asymmetric (
638+ const std::vector<int >& input_sizes,
639+ at::ScalarType dtype) {
640+ std::vector<int64_t > input_sizes_int64 (
641+ input_sizes.begin (), input_sizes.end ());
642+ at::Tensor input =
643+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
644+
645+ // Get reference output
646+ auto [reference_scale, reference_zero_point] =
647+ choose_qparams_per_token_asymmetric_reference_impl (input, dtype);
648+
649+ // Get implementation output
650+ auto [impl_scale, impl_zero_point] =
651+ torch::executor::native::choose_qparams_per_token_asymmetric_aten (
652+ input, dtype);
653+
654+ // Compare outputs
655+ const bool scale_correct = at::allclose (reference_scale, impl_scale);
656+ const bool zero_point_correct =
657+ at::equal (reference_zero_point, impl_zero_point);
658+
659+ if (!scale_correct || !zero_point_correct) {
660+ std::cout << " \n "
661+ << " Failed with parameters: " << std::endl;
662+
663+ std::cout << " input:" << std::endl;
664+ std::cout << input << std::endl;
665+ std::cout << " reference scale:" << std::endl;
666+ std::cout << reference_scale << std::endl;
667+ std::cout << " implementation scale:" << std::endl;
668+ std::cout << impl_scale << std::endl;
669+ std::cout << " reference zero_point:" << std::endl;
670+ std::cout << reference_zero_point << std::endl;
671+ std::cout << " implementation zero_point:" << std::endl;
672+ std::cout << impl_zero_point << std::endl;
673+ }
674+
675+ ASSERT_TRUE (scale_correct && zero_point_correct);
676+ }
677+
678+ void test_vulkan_choose_qparams_per_token_asymmetric_impl (
679+ const std::vector<int >& input_sizes,
680+ at::ScalarType dtype,
681+ const vkcompute::utils::StorageType in_storage,
682+ const vkcompute::utils::StorageType out_storage) {
683+ std::vector<int64_t > input_sizes_int64 (
684+ input_sizes.begin (), input_sizes.end ());
685+ at::Tensor input =
686+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
687+
688+ // Calculate output sizes
689+ std::vector<int64_t > output_sizes;
690+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
691+ output_sizes.push_back (input.size (i));
692+ }
693+ output_sizes.push_back (1 );
694+
695+ // Get reference output
696+ auto [reference_scale, reference_zero_point] =
697+ torch::executor::native::choose_qparams_per_token_asymmetric_aten (
698+ input, dtype);
699+
700+ // Build Vulkan choose_qparams_per_token_asymmetric graph
701+ using namespace vkcompute ;
702+
703+ GraphConfig config;
704+ config.set_storage_type_override (in_storage);
705+ ComputeGraph graph (config);
706+
707+ IOValueRef r_input = graph.add_input_tensor (
708+ input.sizes ().vec (), from_at_scalartype (input.scalar_type ()), in_storage);
709+
710+ // Output tensors
711+ const ValueRef r_scale =
712+ graph.add_tensor (output_sizes, vkapi::kFloat , out_storage);
713+ const ValueRef r_zero_point =
714+ graph.add_tensor (output_sizes, vkapi::kInt , out_storage);
715+
716+ VK_GET_OP_FN (" choose_qparams_per_token_asymmetric.default" )
717+ (graph,
718+ {
719+ r_input.value ,
720+ r_scale,
721+ r_zero_point,
722+ });
723+
724+ ValueRef staging_scale = graph.set_output_tensor (r_scale);
725+ ValueRef staging_zero_point = graph.set_output_tensor (r_zero_point);
726+
727+ graph.prepare ();
728+ graph.encode_prepack ();
729+ graph.prepack ();
730+ graph.encode_execute ();
731+
732+ // Run Vulkan choose_qparams_per_token_asymmetric
733+ graph.copy_into_staging (
734+ r_input.staging , input.const_data_ptr (), input.numel ());
735+
736+ graph.execute ();
737+
738+ // Create output tensors to hold the results - use types that match GPU output
739+ at::Tensor vk_scale =
740+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kFloat ))
741+ .contiguous ();
742+ at::Tensor vk_zero_point =
743+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kInt ))
744+ .contiguous ();
745+
746+ // Copy results from GPU to CPU
747+ graph.copy_from_staging (
748+ staging_scale, vk_scale.mutable_data_ptr (), vk_scale.numel ());
749+ graph.copy_from_staging (
750+ staging_zero_point,
751+ vk_zero_point.mutable_data_ptr (),
752+ vk_zero_point.numel ());
753+
754+ // Convert reference values to match Vulkan output types for comparison
755+ at::Tensor reference_scale_float = reference_scale.to (at::kFloat );
756+ at::Tensor reference_zero_point_int = reference_zero_point.to (at::kInt );
757+
758+ // Compare outputs
759+ const bool scale_correct = at::allclose (reference_scale_float, vk_scale);
760+ const bool zero_point_correct =
761+ at::equal (reference_zero_point_int, vk_zero_point);
762+ if (!scale_correct || !zero_point_correct) {
763+ std::cout << " \n "
764+ << " Failed with parameters: " << std::endl;
765+ std::cout << " storage type: "
766+ << (in_storage == vkcompute::utils::kBuffer ? " buffer"
767+ : " texture" )
768+ << std::endl;
769+
770+ if (input.numel () < 100 ) {
771+ std::cout << " input:" << std::endl;
772+ std::cout << input << " \n " << std::endl;
773+ std::cout << " reference scale:" << std::endl;
774+ std::cout << reference_scale << std::endl;
775+ std::cout << " vulkan scale:" << std::endl;
776+ std::cout << vk_scale << " \n " << std::endl;
777+ std::cout << " reference zero_point:" << std::endl;
778+ std::cout << reference_zero_point << std::endl;
779+ std::cout << " vulkan zero_point:" << std::endl;
780+ std::cout << vk_zero_point << std::endl;
781+ }
782+ }
783+
784+ ASSERT_TRUE (scale_correct && zero_point_correct);
785+ }
786+
787+ TEST (
788+ VulkanChooseQparamsTest,
789+ test_reference_choose_qparams_per_token_asymmetric_int8) {
790+ test_reference_choose_qparams_per_token_asymmetric (
791+ {2 , 3 , 4 }, // input sizes (2*3=6 tokens)
792+ at::kChar );
793+ }
0 commit comments