diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp
index 6f8fdead5b7..b4c4ac274dc 100644
--- a/backends/vulkan/test/op_tests/dequantize_test.cpp
+++ b/backends/vulkan/test/op_tests/dequantize_test.cpp
@@ -744,49 +744,6 @@ void test_vulkan_dequantize_per_tensor_tensor(
       vkcompute::utils::kTexture3D);
 }
 
-// Wrapper function to test both buffer and texture storage types
-void test_vulkan_dequantize_per_channel(
-    const std::vector<int>& input_sizes,
-    const std::vector<float>& scales,
-    const std::vector<int>& zero_points,
-    int64_t axis,
-    int64_t quant_min,
-    int64_t quant_max,
-    at::ScalarType dtype,
-    at::ScalarType out_dtype) {
-  // Test with buffer storage
-  test_vulkan_dequantize_per_channel_impl(
-      input_sizes,
-      scales,
-      zero_points,
-      axis,
-      quant_min,
-      quant_max,
-      dtype,
-      out_dtype,
-      vkcompute::utils::kBuffer,
-      vkcompute::utils::kBuffer);
-
-  // Telling the system to expect a float instead of a double
-  // since the shader can only return 32bit anyways
-  if (out_dtype == at::kDouble) {
-    out_dtype = at::kFloat;
-  }
-
-  // Test with texture storage
-  test_vulkan_dequantize_per_channel_impl(
-      input_sizes,
-      scales,
-      zero_points,
-      axis,
-      quant_min,
-      quant_max,
-      dtype,
-      out_dtype,
-      vkcompute::utils::kTexture3D,
-      vkcompute::utils::kTexture3D);
-}
-
 void test_reference_dequantize_per_tensor(
     const std::vector<int>& input_sizes,
     float scale,
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 04adf183e55..926452dd388 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -1964,3 +1964,99 @@ def forward(self, x):
             GroupNormModule(num_groups, num_channels),
             sample_inputs,
         )
+
+    def test_vulkan_backend_full_quantization_workflow(self):
+        class FullQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per tensor
+                scale, zero_point = (
+                    torch.ops.quantized_decomposed.choose_qparams.tensor(
+                        x,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        eps=1e-5,
+                        dtype=torch.int32,
+                    )
+                )
+
+                # Step 2: Quantize using the calculated parameters
+                quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float
+                dequantized = (
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
+                        quantized,
+                        scale,
+                        zero_point,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        dtype=torch.int32,
+                    )
+                )
+
+                return dequantized
+
+        full_workflow_module = FullQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )
+
+    def test_vulkan_backend_full_per_token_quantization_workflow(self):
+        class FullPerTokenQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per token
+                scale, zero_point = (
+                    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+                        x,
+                        dtype=torch.int32,
+                    )
+                )
+
+                # Step 2: Quantize using the calculated parameters per token
+                quantized = torch.ops.quantized_decomposed.quantize_per_token.default(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float per token
+                dequantized = (
+                    torch.ops.quantized_decomposed.dequantize_per_token.default(
+                        quantized,
+                        scale,
+                        zero_point,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        dtype=torch.int32,
+                        output_dtype=torch.float32,
+                    )
+                )
+
+                return dequantized
+
+        full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )