diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 6f8fdead5b7..1c2dc43dbb5 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -745,21 +745,19 @@ void test_vulkan_dequantize_per_tensor_tensor( } // Wrapper function to test both buffer and texture storage types -void test_vulkan_dequantize_per_channel( +void test_vulkan_dequantize_per_tensor_tensor( const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, + float scale, + int zero_point, int64_t quant_min, int64_t quant_max, at::ScalarType dtype, at::ScalarType out_dtype) { // Test with buffer storage - test_vulkan_dequantize_per_channel_impl( + test_vulkan_dequantize_per_tensor_tensor_impl( input_sizes, - scales, - zero_points, - axis, + scale, + zero_point, quant_min, quant_max, dtype, @@ -774,11 +772,10 @@ void test_vulkan_dequantize_per_channel( } // Test with texture storage - test_vulkan_dequantize_per_channel_impl( + test_vulkan_dequantize_per_tensor_tensor_impl( input_sizes, - scales, - zero_points, - axis, + scale, + zero_point, quant_min, quant_max, dtype, diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 04adf183e55..926452dd388 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1964,3 +1964,99 @@ def forward(self, x): GroupNormModule(num_groups, num_channels), sample_inputs, ) + + def test_vulkan_backend_full_quantization_workflow(self): + class FullQuantizationWorkflowModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # Step 1: Choose quantization parameters per tensor + scale, zero_point = ( + torch.ops.quantized_decomposed.choose_qparams.tensor( + x, + quant_min=-2147483648, # int32 min + quant_max=2147483647, # int32 max + eps=1e-5, + dtype=torch.int32, + ) + ) + + # Step 2: Quantize using the calculated parameters + quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor( + x, + scale, + zero_point, + quant_min=-2147483648, # int32 min + quant_max=2147483647, # int32 max + dtype=torch.int32, + ) + + # Step 3: Dequantize back to float + dequantized = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor( + quantized, + scale, + zero_point, + quant_min=-2147483648, # int32 min + quant_max=2147483647, # int32 max + dtype=torch.int32, + ) + ) + + return dequantized + + full_workflow_module = FullQuantizationWorkflowModule() + sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),) + + # Use higher tolerance since quantization introduces some error + self.lower_module_and_test_output( + full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3 + ) + + def test_vulkan_backend_full_per_token_quantization_workflow(self): + class FullPerTokenQuantizationWorkflowModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # Step 1: Choose quantization parameters per token + scale, zero_point = ( + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default( + x, + dtype=torch.int32, + ) + ) + + # Step 2: Quantize using the calculated parameters per token + quantized = torch.ops.quantized_decomposed.quantize_per_token.default( + x, + scale, + zero_point, + quant_min=-2147483648, # int32 min + quant_max=2147483647, # int32 max + dtype=torch.int32, + ) + + # Step 3: Dequantize back to float per token + dequantized = ( + torch.ops.quantized_decomposed.dequantize_per_token.default( + quantized, + scale, + zero_point, + quant_min=-2147483648, # int32 min + quant_max=2147483647, # int32 max + dtype=torch.int32, + output_dtype=torch.float32, + ) + ) + + return dequantized + + full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule() + sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),) + + # Use higher tolerance since quantization introduces some error + self.lower_module_and_test_output( + full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3 + )