@@ -275,3 +275,330 @@ void check_quantize_args(
275275 " actual quant_max: " ,
276276 quant_max);
277277}
278+ /*
279+ * Reference implementation of quantize_per_token
280+ */
281+ at::Tensor quantize_per_token_reference_impl (
282+ const at::Tensor& input,
283+ const at::Tensor& scale,
284+ const at::Tensor& zero_point,
285+ int64_t quant_min,
286+ int64_t quant_max,
287+ at::ScalarType dtype) {
288+ // Create output tensor with the target dtype
289+ at::Tensor out = at::empty_like (input, dtype);
290+
291+ // Calculate number of tokens
292+ int num_tokens = 1 ;
293+ for (int i = 0 ; i < input.dim () - 1 ; i++) {
294+ num_tokens *= input.size (i);
295+ }
296+
297+ // Verify that the number of tokens matches the size of scale and zero_point
298+ // tensors
299+ assert (num_tokens == scale.numel ());
300+ assert (num_tokens == zero_point.numel ());
301+
302+ // Reshape input to [num_tokens, last_dim]
303+ at::Tensor reshaped_input = input.reshape ({num_tokens, input.size (-1 )});
304+ at::Tensor reshaped_out = out.reshape ({num_tokens, input.size (-1 )});
305+
306+ // Quantize each token separately
307+ for (int token_idx = 0 ; token_idx < num_tokens; token_idx++) {
308+ // Use float for scale since Vulkan doesn't support double
309+ float token_scale = scale[token_idx].item <float >();
310+ // Use int for zero_point since Vulkan doesn't support int64_t
311+ int token_zero_point = zero_point[token_idx].item <int >();
312+
313+ float inv_scale = 1.0 / token_scale;
314+
315+ // Quantize the token
316+ for (int i = 0 ; i < input.size (-1 ); i++) {
317+ float value = reshaped_input[token_idx][i].item <float >();
318+ int qvalue = token_zero_point + std::nearbyint (inv_scale * value);
319+
320+ qvalue = std::max<int64_t >(qvalue, quant_min);
321+ qvalue = std::min<int64_t >(qvalue, quant_max);
322+
323+ if (dtype == at::kByte ) {
324+ reshaped_out[token_idx][i] = static_cast <uint8_t >(qvalue);
325+ } else if (dtype == at::kChar ) {
326+ reshaped_out[token_idx][i] = static_cast <int8_t >(qvalue);
327+ } else if (dtype == at::kShort ) {
328+ reshaped_out[token_idx][i] = static_cast <int16_t >(qvalue);
329+ } else if (dtype == at::kInt ) {
330+ reshaped_out[token_idx][i] = static_cast <int32_t >(qvalue);
331+ } else if (dtype == at::kLong ) {
332+ reshaped_out[token_idx][i] = static_cast <int64_t >(qvalue);
333+ }
334+ }
335+ }
336+
337+ return out;
338+ }
339+
// Forward declaration so test_vulkan_quantize_per_token (below) can dispatch
// to it; the definition further down supplies the default storage arguments.
void test_vulkan_quantize_per_token_impl(
    const std::vector<int>& input_sizes,
    const std::vector<float>& scales,
    const std::vector<int>& zero_points,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    const vkcompute::utils::StorageType in_storage,
    const vkcompute::utils::StorageType out_storage);
349+
350+ // Wrapper function to test both buffer and texture storage types
351+ void test_vulkan_quantize_per_token (
352+ const std::vector<int >& input_sizes,
353+ const std::vector<float >& scales,
354+ const std::vector<int >& zero_points,
355+ int64_t quant_min,
356+ int64_t quant_max,
357+ at::ScalarType dtype) {
358+ // Test with buffer storage
359+ test_vulkan_quantize_per_token_impl (
360+ input_sizes,
361+ scales,
362+ zero_points,
363+ quant_min,
364+ quant_max,
365+ dtype,
366+ vkcompute::utils::kBuffer ,
367+ vkcompute::utils::kBuffer );
368+
369+ // Test with texture storage
370+ test_vulkan_quantize_per_token_impl (
371+ input_sizes,
372+ scales,
373+ zero_points,
374+ quant_min,
375+ quant_max,
376+ dtype,
377+ vkcompute::utils::kTexture3D ,
378+ vkcompute::utils::kTexture3D );
379+ }
380+
381+ void test_reference_quantize_per_token (
382+ const std::vector<int >& input_sizes,
383+ const std::vector<float >& scales,
384+ const std::vector<int >& zero_points,
385+ int64_t quant_min,
386+ int64_t quant_max,
387+ at::ScalarType dtype) {
388+ check_quantize_args (quant_min, quant_max, dtype);
389+ std::vector<int64_t > input_sizes_int64 (
390+ input_sizes.begin (), input_sizes.end ());
391+ at::Tensor input =
392+ at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
393+
394+ // Fill with a simple pattern: values from 0 to 1 in steps
395+ float step = 1.0 / (input.numel () - 1 );
396+ auto flat_input = input.flatten ();
397+ for (int i = 0 ; i < flat_input.numel (); i++) {
398+ flat_input[i] = i * step;
399+ }
400+
401+ // Reshape back to original dimensions
402+ input = flat_input.reshape (input_sizes_int64);
403+
404+ // Calculate number of tokens
405+ int num_tokens = 1 ;
406+ for (int i = 0 ; i < input.dim () - 1 ; i++) {
407+ num_tokens *= input.size (i);
408+ }
409+
410+ // Verify that the number of tokens matches the size of scales and zero_points
411+ ASSERT_EQ (num_tokens, scales.size ());
412+ ASSERT_EQ (num_tokens, zero_points.size ());
413+
414+ // Create scale and zero_point tensors
415+ at::Tensor scale_tensor =
416+ at::tensor (scales, at::device (at::kCPU ).dtype (at::kDouble ));
417+ at::Tensor zero_point_tensor =
418+ at::tensor (zero_points, at::device (at::kCPU ).dtype (at::kLong ));
419+
420+ // Get reference output
421+ at::Tensor reference_out = quantize_per_token_reference_impl (
422+ input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype);
423+
424+ // Get implementation output
425+ at::Tensor impl_out = torch::executor::native::quantize_per_token_aten (
426+ input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype);
427+
428+ // Convert to int for consistent display regardless of underlying type
429+ at::Tensor reference_int = reference_out.to (at::kInt );
430+ at::Tensor impl_int = impl_out.to (at::kInt );
431+
432+ const bool output_correct = at::equal (reference_int, impl_out);
433+ if (!output_correct) {
434+ std::cout << " \n "
435+ << " Failed with parameters: " << std::endl;
436+ std::cout << " scale(s):" ;
437+ for (size_t i = 0 ; i < scales.size (); i++) {
438+ std::cout << " " << scales[i] << " " ;
439+ }
440+ std::cout << " " << std::endl;
441+ std::cout << " zero_point(s):" ;
442+ for (size_t i = 0 ; i < zero_points.size (); i++) {
443+ std::cout << " " << zero_points[i] << " " ;
444+ }
445+ std::cout << " " << std::endl;
446+ std::cout << " quant_min: " << quant_min << std::endl;
447+ std::cout << " quant_max: " << quant_max << std::endl;
448+
449+ std::cout << " input:" << std::endl;
450+ std::cout << input << std::endl;
451+ std::cout << " reference:" << std::endl;
452+ std::cout << reference_int << std::endl;
453+ std::cout << " my_reference:" << std::endl;
454+ std::cout << impl_out << std::endl;
455+ }
456+
457+ ASSERT_TRUE (output_correct);
458+ }
459+
460+ void test_vulkan_quantize_per_token_impl (
461+ const std::vector<int >& input_sizes,
462+ const std::vector<float >& scales,
463+ const std::vector<int >& zero_points,
464+ int64_t quant_min,
465+ int64_t quant_max,
466+ at::ScalarType dtype,
467+ const vkcompute::utils::StorageType in_storage =
468+ vkcompute::utils::kTexture3D ,
469+ const vkcompute::utils::StorageType out_storage =
470+ vkcompute::utils::kTexture3D ) {
471+ check_quantize_args (quant_min, quant_max, dtype);
472+ int num_tokens = 1 ;
473+ for (int i = 0 ; i < input_sizes.size () - 1 ; i++) {
474+ num_tokens *= input_sizes[i];
475+ }
476+
477+ ASSERT_EQ (num_tokens, scales.size ());
478+ ASSERT_EQ (num_tokens, zero_points.size ());
479+
480+ // Create input tensor with random values
481+ std::vector<int64_t > input_sizes_int64 (
482+ input_sizes.begin (), input_sizes.end ());
483+ at::Tensor input =
484+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
485+ at::Tensor scale_tensor =
486+ at::tensor (scales, at::device (at::kCPU ).dtype (at::kDouble ));
487+ at::Tensor zero_point_tensor =
488+ at::tensor (zero_points, at::device (at::kCPU ).dtype (at::kLong ));
489+
490+ // Get reference output to show what we would compare against
491+ at::Tensor reference_out = torch::executor::native::quantize_per_token_aten (
492+ input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype);
493+
494+ using namespace vkcompute ;
495+
496+ GraphConfig config;
497+ config.set_storage_type_override (in_storage);
498+ ComputeGraph graph (config);
499+
500+ IOValueRef r_input = graph.add_input_tensor (
501+ input.sizes ().vec (), from_at_scalartype (input.scalar_type ()), in_storage);
502+ IOValueRef r_scale = graph.add_input_tensor (
503+ scale_tensor.sizes ().vec (), vkapi::kFloat , in_storage);
504+ IOValueRef r_zero_point = graph.add_input_tensor (
505+ zero_point_tensor.sizes ().vec (), vkapi::kInt , in_storage);
506+
507+ const ValueRef r_quant_min = graph.add_scalar <int64_t >(quant_min);
508+ const ValueRef r_quant_max = graph.add_scalar <int64_t >(quant_max);
509+
510+ const ValueRef r_out = graph.add_tensor (
511+ input.sizes ().vec (), from_at_scalartype (dtype), out_storage);
512+
513+ VK_GET_OP_FN (" quantize_per_token.default" )
514+ (graph,
515+ {
516+ r_input.value ,
517+ r_scale.value ,
518+ r_zero_point.value ,
519+ r_quant_min,
520+ r_quant_max,
521+ r_out,
522+ });
523+
524+ ValueRef staging_out = graph.set_output_tensor (r_out);
525+
526+ graph.prepare ();
527+ graph.encode_prepack ();
528+ graph.prepack ();
529+ graph.encode_execute ();
530+
531+ // Copy input data to GPU
532+ graph.copy_into_staging (
533+ r_input.staging , input.const_data_ptr (), input.numel ());
534+
535+ // Convert scale tensor to float and copy to GPU
536+ at::Tensor scale_float = scale_tensor.to (at::kFloat );
537+ graph.copy_into_staging (
538+ r_scale.staging , scale_float.const_data_ptr (), scale_float.numel ());
539+
540+ // Convert zero_point tensor to int and copy to GPU
541+ at::Tensor zero_point_int = zero_point_tensor.to (at::kInt );
542+ graph.copy_into_staging (
543+ r_zero_point.staging ,
544+ zero_point_int.const_data_ptr (),
545+ zero_point_int.numel ());
546+
547+ // Execute the graph
548+ graph.execute ();
549+
550+ // Copy output data back to CPU
551+ at::Tensor vk_out = at::empty_like (reference_out).contiguous ();
552+ graph.copy_from_staging (
553+ staging_out, vk_out.mutable_data_ptr (), vk_out.numel ());
554+
555+ // Compare outputs
556+ at::Tensor reference_int = reference_out.to (at::kInt );
557+ at::Tensor vk_int = vk_out.to (at::kInt );
558+
559+ const bool output_correct = at::equal (reference_int, vk_int);
560+ if (!output_correct) {
561+ at::Tensor diffs = at::abs (reference_int - vk_int);
562+
563+ std::cout << " \n "
564+ << " Failed with parameters: " << std::endl;
565+ std::cout << " scale(s):" ;
566+ for (size_t i = 0 ; i < scales.size (); i++) {
567+ std::cout << " " << scales[i] << " " ;
568+ }
569+ std::cout << " " << std::endl;
570+ std::cout << " zero_point(s):" ;
571+ for (size_t i = 0 ; i < zero_points.size (); i++) {
572+ std::cout << " " << zero_points[i] << " " ;
573+ }
574+ std::cout << " " << std::endl;
575+ std::cout << " quant_min: " << quant_min << std::endl;
576+ std::cout << " quant_max: " << quant_max << std::endl;
577+ std::cout << " storage type: "
578+ << (in_storage == vkcompute::utils::kBuffer ? " buffer"
579+ : " texture" )
580+ << std::endl;
581+
582+ std::cout << " input:" << std::endl;
583+ std::cout << input << std::endl;
584+ std::cout << " reference:" << std::endl;
585+ std::cout << reference_int << std::endl;
586+ std::cout << " vulkan:" << std::endl;
587+ std::cout << vk_int << std::endl;
588+ }
589+
590+ ASSERT_TRUE (output_correct);
591+ }
592+
593+ TEST (VulkanQuantizePerTensorTest, test_reference_quantize_per_token_int8) {
594+ std::vector<float > scales = {0.1 , 0 , 0.3 , 0.1 , 0.2 , 0.3 };
595+ std::vector<int > zero_points = {1 , 2 , 3 , 0 , -1 , -2 };
596+
597+ test_reference_quantize_per_token (
598+ {2 , 3 , 4 }, // input sizes (2*3=6 tokens)
599+ scales,
600+ zero_points,
601+ -128 , // quant_min
602+ 127 , // quant_max
603+ at::kChar );
604+ }
0 commit comments