@@ -294,3 +294,328 @@ void check_dequantize_args(
294294 " )" );
295295 }
296296}
297+
298+ //
299+ // Reference Implementation
300+ //
301+
302+ /*
303+ * Reference implementation of dequantize_per_tensor
304+ */
305+ at::Tensor dequantize_per_tensor_reference_impl (
306+ const at::Tensor& input,
307+ double scale,
308+ int64_t zero_point,
309+ int64_t quant_min,
310+ int64_t quant_max,
311+ at::ScalarType dtype,
312+ at::ScalarType out_dtype) {
313+ // Create output tensor with the target dtype
314+ at::Tensor out = at::empty_like (input, out_dtype);
315+
316+ // Dequantize the input tensor
317+ at::Tensor int_input = input.to (at::kInt );
318+ at::Tensor flat_input = int_input.flatten ();
319+ at::Tensor flat_out = out.flatten ();
320+
321+ for (int i = 0 ; i < flat_input.numel (); i++) {
322+ int64_t qvalue = flat_input[i].item <int64_t >();
323+ float value = static_cast <float >((qvalue - zero_point) * scale);
324+
325+ if (out_dtype == at::kFloat ) {
326+ flat_out[i] = value;
327+ } else if (out_dtype == at::kDouble ) {
328+ flat_out[i] = static_cast <double >(value);
329+ }
330+ }
331+
332+ return out.reshape (input.sizes ());
333+ }
334+
// Forward declaration: the implementation is defined after the reference-path
// tests, but the storage-type wrapper below needs to call it.
void test_vulkan_dequantize_per_tensor_impl(
    const std::vector<int>& input_sizes,
    float scale,
    int zero_point,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    at::ScalarType out_dtype,
    const vkcompute::utils::StorageType in_storage,
    const vkcompute::utils::StorageType out_storage);
346+
347+ // Wrapper function to test both buffer and texture storage types
348+ void test_vulkan_dequantize_per_tensor (
349+ const std::vector<int >& input_sizes,
350+ float scale,
351+ int zero_point,
352+ int64_t quant_min,
353+ int64_t quant_max,
354+ at::ScalarType dtype,
355+ at::ScalarType out_dtype) {
356+ // Test with buffer storage
357+ test_vulkan_dequantize_per_tensor_impl (
358+ input_sizes,
359+ scale,
360+ zero_point,
361+ quant_min,
362+ quant_max,
363+ dtype,
364+ out_dtype,
365+ vkcompute::utils::kBuffer ,
366+ vkcompute::utils::kBuffer );
367+
368+ // Test with texture storage
369+ test_vulkan_dequantize_per_tensor_impl (
370+ input_sizes,
371+ scale,
372+ zero_point,
373+ quant_min,
374+ quant_max,
375+ dtype,
376+ out_dtype,
377+ vkcompute::utils::kTexture3D ,
378+ vkcompute::utils::kTexture3D );
379+ }
380+
381+ void test_reference_dequantize_per_tensor (
382+ const std::vector<int >& input_sizes,
383+ float scale,
384+ int zero_point,
385+ int64_t quant_min,
386+ int64_t quant_max,
387+ at::ScalarType dtype,
388+ at::ScalarType out_dtype) {
389+ check_dequantize_args (quant_min, quant_max, dtype, out_dtype);
390+ std::vector<int64_t > input_sizes_int64 (
391+ input_sizes.begin (), input_sizes.end ());
392+
393+ // Create a quantized input tensor with values from quant_min to quant_max
394+ at::Tensor input;
395+ if (dtype == at::kByte ) {
396+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kByte ));
397+ } else if (dtype == at::kChar ) {
398+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kChar ));
399+ } else if (dtype == at::kShort ) {
400+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kShort ));
401+ } else if (dtype == at::kInt ) {
402+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kInt ));
403+ } else {
404+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kLong ));
405+ }
406+
407+ // Fill with a simple pattern: values from quant_min to quant_max in steps
408+ float step = 1 .0f ;
409+ if (input.numel () > 1 ) {
410+ step = static_cast <float >(quant_max - quant_min) / (input.numel () - 1 );
411+ }
412+
413+ auto flat_input = input.flatten ();
414+ for (int i = 0 ; i < flat_input.numel (); i++) {
415+ int64_t qvalue = quant_min + i * step;
416+ if (dtype == at::kByte ) {
417+ flat_input[i] = static_cast <uint8_t >(qvalue);
418+ } else if (dtype == at::kChar ) {
419+ flat_input[i] = static_cast <int8_t >(qvalue);
420+ } else if (dtype == at::kShort ) {
421+ flat_input[i] = static_cast <int16_t >(qvalue);
422+ } else if (dtype == at::kInt ) {
423+ flat_input[i] = static_cast <int32_t >(qvalue);
424+ } else if (dtype == at::kLong ) {
425+ flat_input[i] = static_cast <int64_t >(qvalue);
426+ }
427+ }
428+
429+ // Reshape back to original dimensions
430+ input = flat_input.reshape (input_sizes_int64);
431+
432+ // Get reference output
433+ at::Tensor reference_out = dequantize_per_tensor_reference_impl (
434+ input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
435+
436+ // Get implementation output
437+ at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten (
438+ input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
439+
440+ // Compare outputs
441+ const bool output_correct = at::allclose (reference_out, impl_out, 1e-5 , 1e-5 );
442+ if (!output_correct) {
443+ std::cout << " \n "
444+ << " Failed with parameters: " << std::endl;
445+ std::cout << " scale: " << scale << std::endl;
446+ std::cout << " zero_point: " << zero_point << std::endl;
447+ std::cout << " quant_min: " << quant_min << std::endl;
448+ std::cout << " quant_max: " << quant_max << std::endl;
449+
450+ std::cout << " input:" << std::endl;
451+ std::cout << input << std::endl;
452+ std::cout << " reference:" << std::endl;
453+ std::cout << reference_out << std::endl;
454+ std::cout << " implementation:" << std::endl;
455+ std::cout << impl_out << std::endl;
456+ }
457+
458+ ASSERT_TRUE (output_correct);
459+ }
460+
461+ void test_vulkan_dequantize_per_tensor_impl (
462+ const std::vector<int >& input_sizes,
463+ float scale,
464+ int zero_point,
465+ int64_t quant_min,
466+ int64_t quant_max,
467+ at::ScalarType dtype,
468+ at::ScalarType out_dtype,
469+ const vkcompute::utils::StorageType in_storage,
470+ const vkcompute::utils::StorageType out_storage) {
471+ check_dequantize_args (quant_min, quant_max, dtype, out_dtype);
472+ std::vector<int64_t > input_sizes_int64 (
473+ input_sizes.begin (), input_sizes.end ());
474+
475+ // Create a quantized input tensor with values from quant_min to quant_max
476+ at::Tensor input;
477+ if (dtype == at::kByte ) {
478+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kByte ));
479+ } else if (dtype == at::kChar ) {
480+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kChar ));
481+ } else if (dtype == at::kShort ) {
482+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kShort ));
483+ } else if (dtype == at::kInt ) {
484+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kInt ));
485+ } else {
486+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kLong ));
487+ }
488+
489+ // Fill with a simple pattern: values from quant_min to quant_max in steps
490+ float step = 1 .0f ;
491+ if (input.numel () > 1 ) {
492+ step = static_cast <float >(quant_max - quant_min) / (input.numel () - 1 );
493+ }
494+
495+ auto flat_input = input.flatten ();
496+ for (int i = 0 ; i < flat_input.numel (); i++) {
497+ int64_t qvalue = quant_min + i * step;
498+ if (dtype == at::kByte ) {
499+ flat_input[i] = static_cast <uint8_t >(qvalue);
500+ } else if (dtype == at::kChar ) {
501+ flat_input[i] = static_cast <int8_t >(qvalue);
502+ } else if (dtype == at::kShort ) {
503+ flat_input[i] = static_cast <int16_t >(qvalue);
504+ } else if (dtype == at::kInt ) {
505+ flat_input[i] = static_cast <int32_t >(qvalue);
506+ } else if (dtype == at::kLong ) {
507+ flat_input[i] = static_cast <int64_t >(qvalue);
508+ }
509+ }
510+
511+ // Reshape back to original dimensions
512+ input = flat_input.reshape (input_sizes_int64);
513+
514+ // Get reference output
515+ at::Tensor reference_out = torch::executor::native::dequantize_per_tensor_aten (
516+ input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
517+
518+ // Build Vulkan dequantize_per_tensor graph
519+ using namespace vkcompute ;
520+
521+ GraphConfig config;
522+ config.set_storage_type_override (in_storage);
523+ ComputeGraph graph (config);
524+
525+ IOValueRef r_input = graph.add_input_tensor (
526+ input.sizes ().vec (), from_at_scalartype (dtype), in_storage);
527+
528+ const ValueRef r_scale = graph.add_scalar <double >(scale);
529+ const ValueRef r_zero_point = graph.add_scalar <int64_t >(zero_point);
530+ const ValueRef r_quant_min = graph.add_scalar <int64_t >(quant_min);
531+ const ValueRef r_quant_max = graph.add_scalar <int64_t >(quant_max);
532+
533+ const ValueRef r_out = graph.add_tensor (
534+ input.sizes ().vec (), from_at_scalartype (out_dtype), out_storage);
535+
536+ VK_GET_OP_FN (" dequantize_per_tensor.default" )
537+ (graph,
538+ {
539+ r_input.value ,
540+ r_scale,
541+ r_zero_point,
542+ r_quant_min,
543+ r_quant_max,
544+ r_out,
545+ });
546+
547+ ValueRef staging_out = graph.set_output_tensor (r_out);
548+
549+ graph.prepare ();
550+ graph.encode_prepack ();
551+ graph.prepack ();
552+ graph.encode_execute ();
553+
554+ // Run Vulkan dequantize_per_tensor
555+ graph.copy_into_staging (
556+ r_input.staging , input.const_data_ptr (), input.numel ());
557+
558+ graph.execute ();
559+
560+ at::Tensor vk_out = at::empty_like (reference_out).contiguous ();
561+ graph.copy_from_staging (
562+ staging_out, vk_out.mutable_data_ptr (), vk_out.numel ());
563+
564+ // Compare outputs
565+ const bool output_correct = at::allclose (reference_out, vk_out, 1e-5 , 1e-5 );
566+ if (!output_correct) {
567+ std::cout << " \n "
568+ << " Failed with parameters: " << std::endl;
569+ std::cout << " scale: " << scale << std::endl;
570+ std::cout << " zero_point: " << zero_point << std::endl;
571+ std::cout << " quant_min: " << quant_min << std::endl;
572+ std::cout << " quant_max: " << quant_max << std::endl;
573+ std::cout << " storage type: "
574+ << (in_storage == vkcompute::utils::kBuffer ? " buffer"
575+ : " texture" )
576+ << std::endl;
577+
578+ std::cout << " input:" << std::endl;
579+ std::cout << input << std::endl;
580+ std::cout << " reference:" << std::endl;
581+ std::cout << reference_out << std::endl;
582+ std::cout << " vulkan:" << std::endl;
583+ std::cout << vk_out << std::endl;
584+ }
585+
586+ ASSERT_TRUE (output_correct);
587+ }
588+
// Test cases for dequantize_per_tensor
// uint8 -> float over the full [0, 255] range.
TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_uint8_to_float) {
  test_reference_dequantize_per_tensor(
      {2, 3, 4}, // input sizes
      0.1, // scale
      5, // zero_point
      0, // quant_min
      255, // quant_max
      at::kByte, // input dtype
      at::kFloat); // output dtype
}
600+
// int8 -> float over the full [-128, 127] range, zero_point at 0.
TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_int8_to_float) {
  test_reference_dequantize_per_tensor(
      {3, 4, 5}, // input sizes
      0.05, // scale
      0, // zero_point
      -128, // quant_min
      127, // quant_max
      at::kChar, // input dtype
      at::kFloat); // output dtype
}
611+
// int16 -> float over the full [-32768, 32767] range with a negative
// zero_point and a small scale.
TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_int16_to_float) {
  test_reference_dequantize_per_tensor(
      {2, 2, 3}, // input sizes
      0.001, // scale
      -10, // zero_point
      -32768, // quant_min
      32767, // quant_max
      at::kShort, // input dtype
      at::kFloat); // output dtype
}