@@ -193,6 +193,111 @@ std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl(
193193 return std::make_tuple (scale_out, zero_point_out);
194194}
195195
196+ /*
197+ * Reference implementation of choose_qparams_per_token_asymmetric
198+ */
199+ std::tuple<at::Tensor, at::Tensor>
200+ choose_qparams_per_token_asymmetric_reference_impl (
201+ const at::Tensor& input,
202+ at::ScalarType dtype) {
203+ // For per-token quantization, we need to compute scale and zero_point for
204+ // each token
205+ int64_t quant_min = -128 ;
206+ int64_t quant_max = 127 ;
207+
208+ // Calculate output sizes
209+ std::vector<int64_t > output_sizes;
210+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
211+ output_sizes.push_back (input.size (i));
212+ }
213+ output_sizes.push_back (1 );
214+
215+ // Create output tensors
216+ at::Tensor scale_out =
217+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kDouble ));
218+ at::Tensor zero_point_out =
219+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kLong ));
220+
221+ // Calculate number of tokens
222+ int64_t num_tokens = 1 ;
223+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
224+ num_tokens *= input.size (i);
225+ }
226+
227+ // Reshape input to [num_tokens, last_dim]
228+ at::Tensor reshaped_input = input.reshape ({num_tokens, input.size (-1 )});
229+
230+ // Process each token
231+ for (int64_t token_idx = 0 ; token_idx < num_tokens; token_idx++) {
232+ at::Tensor token = reshaped_input[token_idx];
233+
234+ // Find min and max values for this token
235+ float min_val = token.min ().item <float >();
236+ float max_val = token.max ().item <float >();
237+
238+ // Extend the [min, max] interval to ensure it contains 0
239+ min_val = std::min (min_val, 0 .f );
240+ max_val = std::max (max_val, 0 .f );
241+
242+ // Calculate scale
243+ double scale =
244+ (static_cast <double >(max_val) - min_val) / (quant_max - quant_min);
245+
246+ // Handle small scale
247+ constexpr float SMALL_SCALE_THRESHOLD = 6 .1e-5f ;
248+ if (float (scale) == 0 .0f || std::isinf (1 .0f / float (scale))) {
249+ scale = 0.1 ;
250+ }
251+
252+ if (scale < SMALL_SCALE_THRESHOLD) {
253+ float org_scale = scale;
254+ scale = SMALL_SCALE_THRESHOLD;
255+ // Adjust min and max based on new scale
256+ if (min_val == 0 .0f ) {
257+ max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
258+ } else if (max_val == 0 .0f ) {
259+ min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min);
260+ } else {
261+ float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
262+ min_val *= amplifier;
263+ max_val *= amplifier;
264+ }
265+ }
266+
267+ // Calculate zero point
268+ double zero_point_from_min =
269+ quant_min - min_val / static_cast <double >(scale);
270+ double zero_point_from_max =
271+ quant_max - max_val / static_cast <double >(scale);
272+ double zero_point_from_min_error =
273+ std::abs (quant_min) - std::abs (min_val / static_cast <double >(scale));
274+ double zero_point_from_max_error =
275+ std::abs (quant_max) - std::abs (max_val / static_cast <double >(scale));
276+ double initial_zero_point =
277+ zero_point_from_min_error < zero_point_from_max_error
278+ ? zero_point_from_min
279+ : zero_point_from_max;
280+
281+ // Nudge zero point to be an integer
282+ int64_t nudged_zero_point = 0 ;
283+ if (initial_zero_point < quant_min) {
284+ nudged_zero_point = quant_min;
285+ } else if (initial_zero_point > quant_max) {
286+ nudged_zero_point = quant_max;
287+ } else {
288+ nudged_zero_point =
289+ std::nearbyint (static_cast <float >(initial_zero_point));
290+ }
291+
292+ // Set output values for this token - use index_put_ for safety
293+ scale_out.view ({num_tokens, 1 }).index_put_ ({token_idx, 0 }, scale);
294+ zero_point_out.view ({num_tokens, 1 })
295+ .index_put_ ({token_idx, 0 }, nudged_zero_point);
296+ }
297+
298+ return std::make_tuple (scale_out, zero_point_out);
299+ }
300+
196301// Forward declaration of implementation functions
197302void test_vulkan_choose_qparams_tensor_impl (
198303 const std::vector<int >& input_sizes,
@@ -202,6 +307,12 @@ void test_vulkan_choose_qparams_tensor_impl(
202307 const vkcompute::utils::StorageType in_storage,
203308 const vkcompute::utils::StorageType out_storage);
204309
310+ void test_vulkan_choose_qparams_per_token_asymmetric_impl (
311+ const std::vector<int >& input_sizes,
312+ at::ScalarType dtype,
313+ const vkcompute::utils::StorageType in_storage,
314+ const vkcompute::utils::StorageType out_storage);
315+
205316// Wrapper function to test both buffer and texture storage types
206317void test_vulkan_choose_qparams_tensor (
207318 const std::vector<int >& input_sizes,
@@ -227,6 +338,22 @@ void test_vulkan_choose_qparams_tensor(
227338 vkcompute::utils::kTexture3D );
228339}
229340
341+ // Wrapper function to test both buffer and texture storage types
342+ void test_vulkan_choose_qparams_per_token_asymmetric (
343+ const std::vector<int >& input_sizes,
344+ at::ScalarType dtype) {
345+ // Test with buffer storage
346+ test_vulkan_choose_qparams_per_token_asymmetric_impl (
347+ input_sizes, dtype, vkcompute::utils::kBuffer , vkcompute::utils::kBuffer );
348+
349+ // Test with texture storage
350+ test_vulkan_choose_qparams_per_token_asymmetric_impl (
351+ input_sizes,
352+ dtype,
353+ vkcompute::utils::kTexture3D ,
354+ vkcompute::utils::kTexture3D );
355+ }
356+
230357void test_reference_choose_qparams_tensor (
231358 const std::vector<int >& input_sizes,
232359 int64_t quant_min,
@@ -388,3 +515,161 @@ TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) {
388515 127 , // quant_max
389516 at::kChar );
390517}
518+
519+ void test_reference_choose_qparams_per_token_asymmetric (
520+ const std::vector<int >& input_sizes,
521+ at::ScalarType dtype) {
522+ std::vector<int64_t > input_sizes_int64 (
523+ input_sizes.begin (), input_sizes.end ());
524+ at::Tensor input =
525+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
526+
527+ // Get reference output
528+ auto [reference_scale, reference_zero_point] =
529+ choose_qparams_per_token_asymmetric_reference_impl (input, dtype);
530+
531+ // Get implementation output
532+ auto [impl_scale, impl_zero_point] =
533+ torch::executor::native::choose_qparams_per_token_asymmetric_aten (
534+ input, dtype);
535+
536+ // Compare outputs
537+ const bool scale_correct = at::allclose (reference_scale, impl_scale);
538+ const bool zero_point_correct =
539+ at::equal (reference_zero_point, impl_zero_point);
540+
541+ if (!scale_correct || !zero_point_correct) {
542+ std::cout << " \n "
543+ << " Failed with parameters: " << std::endl;
544+
545+ std::cout << " input:" << std::endl;
546+ std::cout << input << std::endl;
547+ std::cout << " reference scale:" << std::endl;
548+ std::cout << reference_scale << std::endl;
549+ std::cout << " implementation scale:" << std::endl;
550+ std::cout << impl_scale << std::endl;
551+ std::cout << " reference zero_point:" << std::endl;
552+ std::cout << reference_zero_point << std::endl;
553+ std::cout << " implementation zero_point:" << std::endl;
554+ std::cout << impl_zero_point << std::endl;
555+ }
556+
557+ ASSERT_TRUE (scale_correct && zero_point_correct);
558+ }
559+
560+ void test_vulkan_choose_qparams_per_token_asymmetric_impl (
561+ const std::vector<int >& input_sizes,
562+ at::ScalarType dtype,
563+ const vkcompute::utils::StorageType in_storage,
564+ const vkcompute::utils::StorageType out_storage) {
565+ std::vector<int64_t > input_sizes_int64 (
566+ input_sizes.begin (), input_sizes.end ());
567+ at::Tensor input =
568+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
569+
570+ // Calculate output sizes
571+ std::vector<int64_t > output_sizes;
572+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
573+ output_sizes.push_back (input.size (i));
574+ }
575+ output_sizes.push_back (1 );
576+
577+ // Get reference output
578+ auto [reference_scale, reference_zero_point] =
579+ torch::executor::native::choose_qparams_per_token_asymmetric_aten (
580+ input, dtype);
581+
582+ // Build Vulkan choose_qparams_per_token_asymmetric graph
583+ using namespace vkcompute ;
584+
585+ GraphConfig config;
586+ config.set_storage_type_override (in_storage);
587+ ComputeGraph graph (config);
588+
589+ IOValueRef r_input = graph.add_input_tensor (
590+ input.sizes ().vec (), from_at_scalartype (input.scalar_type ()), in_storage);
591+
592+ // Output tensors
593+ const ValueRef r_scale =
594+ graph.add_tensor (output_sizes, vkapi::kFloat , out_storage);
595+ const ValueRef r_zero_point =
596+ graph.add_tensor (output_sizes, vkapi::kInt , out_storage);
597+
598+ VK_GET_OP_FN (" choose_qparams_per_token_asymmetric.default" )
599+ (graph,
600+ {
601+ r_input.value ,
602+ r_scale,
603+ r_zero_point,
604+ });
605+
606+ ValueRef staging_scale = graph.set_output_tensor (r_scale);
607+ ValueRef staging_zero_point = graph.set_output_tensor (r_zero_point);
608+
609+ graph.prepare ();
610+ graph.encode_prepack ();
611+ graph.prepack ();
612+ graph.encode_execute ();
613+
614+ // Run Vulkan choose_qparams_per_token_asymmetric
615+ graph.copy_into_staging (
616+ r_input.staging , input.const_data_ptr (), input.numel ());
617+
618+ graph.execute ();
619+
620+ // Create output tensors to hold the results - use types that match GPU output
621+ at::Tensor vk_scale =
622+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kFloat ))
623+ .contiguous ();
624+ at::Tensor vk_zero_point =
625+ at::empty (output_sizes, at::device (at::kCPU ).dtype (at::kInt ))
626+ .contiguous ();
627+
628+ // Copy results from GPU to CPU
629+ graph.copy_from_staging (
630+ staging_scale, vk_scale.mutable_data_ptr (), vk_scale.numel ());
631+ graph.copy_from_staging (
632+ staging_zero_point,
633+ vk_zero_point.mutable_data_ptr (),
634+ vk_zero_point.numel ());
635+
636+ // Convert reference values to match Vulkan output types for comparison
637+ at::Tensor reference_scale_float = reference_scale.to (at::kFloat );
638+ at::Tensor reference_zero_point_int = reference_zero_point.to (at::kInt );
639+
640+ // Compare outputs
641+ const bool scale_correct = at::allclose (reference_scale_float, vk_scale);
642+ const bool zero_point_correct =
643+ at::equal (reference_zero_point_int, vk_zero_point);
644+ if (!scale_correct || !zero_point_correct) {
645+ std::cout << " \n "
646+ << " Failed with parameters: " << std::endl;
647+ std::cout << " storage type: "
648+ << (in_storage == vkcompute::utils::kBuffer ? " buffer"
649+ : " texture" )
650+ << std::endl;
651+
652+ if (input.numel () < 100 ) {
653+ std::cout << " input:" << std::endl;
654+ std::cout << input << " \n " << std::endl;
655+ std::cout << " reference scale:" << std::endl;
656+ std::cout << reference_scale << std::endl;
657+ std::cout << " vulkan scale:" << std::endl;
658+ std::cout << vk_scale << " \n " << std::endl;
659+ std::cout << " reference zero_point:" << std::endl;
660+ std::cout << reference_zero_point << std::endl;
661+ std::cout << " vulkan zero_point:" << std::endl;
662+ std::cout << vk_zero_point << std::endl;
663+ }
664+ }
665+
666+ ASSERT_TRUE (scale_correct && zero_point_correct);
667+ }
668+
669+ TEST (
670+ VulkanChooseQparamsTest,
671+ test_reference_choose_qparams_per_token_asymmetric_int8) {
672+ test_reference_choose_qparams_per_token_asymmetric (
673+ {2 , 3 , 4 }, // input sizes (2*3=6 tokens)
674+ at::kChar );
675+ }
0 commit comments