@@ -827,6 +827,145 @@ void test_reference_dequantize_per_tensor(
   ASSERT_TRUE(output_correct);
 }
 
+void test_vulkan_dequantize_per_tensor_impl(
+    const std::vector<int>& input_sizes,
+    float scale,
+    int zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype,
+    const vkcompute::utils::StorageType in_storage,
+    const vkcompute::utils::StorageType out_storage) {
+  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+
+  // Create a quantized input tensor with values from quant_min to quant_max
+  at::Tensor input;
+  if (dtype == at::kByte) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
+  } else if (dtype == at::kChar) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
+  } else if (dtype == at::kShort) {
+    input =
+        at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
+  } else if (dtype == at::kInt) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
+  } else {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
+  }
+
+  // Fill with a simple pattern: values from quant_min to quant_max in steps
+  float step = 1.0f;
+  if (input.numel() > 1) {
+    step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
+  }
+
+  auto flat_input = input.flatten();
+  for (int i = 0; i < flat_input.numel(); i++) {
+    int64_t qvalue = quant_min + i * step;
+    if (dtype == at::kByte) {
+      flat_input[i] = static_cast<uint8_t>(qvalue);
+    } else if (dtype == at::kChar) {
+      flat_input[i] = static_cast<int8_t>(qvalue);
+    } else if (dtype == at::kShort) {
+      flat_input[i] = static_cast<int16_t>(qvalue);
+    } else if (dtype == at::kInt) {
+      flat_input[i] = static_cast<int32_t>(qvalue);
+    } else if (dtype == at::kLong) {
+      flat_input[i] = static_cast<int64_t>(qvalue);
+    }
+  }
+
+  // Reshape back to original dimensions
+  input = flat_input.reshape(input_sizes_int64);
+
+  // Get reference output
+  at::Tensor reference_out =
+      torch::executor::native::dequantize_per_tensor_aten(
+          input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
+
+  // Build Vulkan dequantize_per_tensor graph
+  using namespace vkcompute;
+
+  GraphConfig config;
+  config.set_storage_type_override(in_storage);
+  ComputeGraph graph(config);
+
+  IOValueRef r_input = graph.add_input_tensor(
+      input.sizes().vec(), from_at_scalartype(dtype), in_storage);
+
+  const ValueRef r_scale = graph.add_scalar<double>(scale);
+  const ValueRef r_zero_point = graph.add_scalar<int64_t>(zero_point);
+  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
+  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
+
+  const ValueRef r_out = graph.add_tensor(
+      input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);
+
+  VK_GET_OP_FN("dequantize_per_tensor.default")
+  (graph,
+   {
+       r_input.value,
+       r_scale,
+       r_zero_point,
+       r_quant_min,
+       r_quant_max,
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
+  graph.prepare();
+
+  graph.prepack();
+  graph.encode_execute();
+
+  // Run Vulkan dequantize_per_tensor
+  graph.copy_into_staging(
+      r_input.staging, input.const_data_ptr(), input.numel());
+
+  graph.execute();
+
+  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
+  graph.copy_from_staging(
+      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
+
+  // Compare outputs with appropriate tolerance for half precision
+  bool output_correct;
+  if (out_dtype == at::kHalf) {
+    // Use higher tolerance for half precision due to limited precision
+    output_correct =
+        at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
+  } else {
+    output_correct = at::allclose(reference_out, vk_out);
+  }
+  if (!output_correct) {
+    std::cout << "\n"
+              << "Failed with parameters: " << std::endl;
+    std::cout << "  scale: " << scale << std::endl;
+    std::cout << "  zero_point: " << zero_point << std::endl;
+    std::cout << "  quant_min: " << quant_min << std::endl;
+    std::cout << "  quant_max: " << quant_max << std::endl;
+    std::cout << "  storage type: "
+              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
+                                                          : "texture")
+              << std::endl;
+    std::cout << "  input dtype: " << dtype << std::endl;
+    std::cout << "  output dtype: " << out_dtype << std::endl;
+
+    std::cout << "input:" << std::endl;
+    std::cout << input << std::endl;
+    std::cout << "reference:" << std::endl;
+    std::cout << reference_out << std::endl;
+    std::cout << "vulkan:" << std::endl;
+    std::cout << vk_out << std::endl;
+  }
+
+  ASSERT_TRUE(output_correct);
+}
+
 TEST(
     VulkanDequantizePerTensorTest,
     test_reference_dequantize_per_tensor_uint8_to_float) {
@@ -1138,7 +1277,7 @@ void test_vulkan_dequantize_per_token_impl(
   ValueRef staging_out = graph.set_output_tensor(r_out);
 
   graph.prepare();
-  graph.encode_prepack();
+
   graph.prepack();
   graph.encode_execute();
 
@@ -1670,7 +1809,6 @@ void test_vulkan_dequantize_per_channel_impl(
   ValueRef staging_out = graph.set_output_tensor(r_out);
 
   graph.prepare();
-  graph.encode_prepack();
   graph.prepack();
   graph.encode_execute();
 
@@ -2345,7 +2483,6 @@ void test_vulkan_dequantize_per_tensor_tensor_impl(
   ValueRef staging_out = graph.set_output_tensor(r_out);
 
   graph.prepare();
-  graph.encode_prepack();
   graph.prepack();
   graph.encode_execute();
 
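For context, the `_impl` helpers in this test file are typically driven by a thin dispatcher that repeats each case for the available storage types. Below is a minimal sketch of such a wrapper, assuming the `test_vulkan_dequantize_per_tensor_impl` signature added above; the wrapper name and the buffer/texture combinations are illustrative assumptions, not part of this diff.

// Hypothetical dispatcher: exercises the helper added in this diff for both
// buffer-backed and texture-backed tensors. The storage-type constants are
// assumed from vkcompute::utils; only kBuffer appears in the diff itself.
void test_vulkan_dequantize_per_tensor(
    const std::vector<int>& input_sizes,
    float scale,
    int zero_point,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    at::ScalarType out_dtype) {
  // Buffer storage for both input and output
  test_vulkan_dequantize_per_tensor_impl(
      input_sizes, scale, zero_point, quant_min, quant_max, dtype, out_dtype,
      vkcompute::utils::kBuffer, vkcompute::utils::kBuffer);

  // Texture storage for both input and output
  test_vulkan_dequantize_per_tensor_impl(
      input_sizes, scale, zero_point, quant_min, quant_max, dtype, out_dtype,
      vkcompute::utils::kTexture3D, vkcompute::utils::kTexture3D);
}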