@@ -1964,3 +1964,87 @@ def forward(self, x):
             GroupNormModule(num_groups, num_channels),
             sample_inputs,
         )
+
+    def test_vulkan_backend_full_quantization_workflow(self):
+        class FullQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per tensor
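+                # choose_qparams maps the tensor's observed [min, max] range to
+                # [quant_min, quant_max]: scale = (max - min) / (quant_max - quant_min)
+                # (clamped below by eps), and zero_point places real 0.0 on an
+                # exact integer step.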
+                scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
+                    x,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    eps=1e-5,
+                    dtype=torch.int32,
+                )
+
+                # Step 2: Quantize using the calculated parameters
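+                # Affine quantization: round(x / scale) + zero_point, clamped to
+                # [quant_min, quant_max] and cast to the target integer dtype.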
+                quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float
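+                # Inverse affine mapping: (quantized - zero_point) * scale, so
+                # the round-trip error is at most scale / 2 per element (absent
+                # clamping).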
+                dequantized = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
+                    quantized,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                return dequantized
+
+        full_workflow_module = FullQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )
+
+    def test_vulkan_backend_full_per_token_quantization_workflow(self):
+        class FullPerTokenQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per token
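+                # Per-token granularity: each slice along the last dimension
+                # ("token") gets its own scale and zero_point, so one outlier
+                # token does not widen the quantization range of the others.
+                # The asymmetric variant leaves zero_point free instead of
+                # pinning it to 0.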
+                scale, zero_point = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+                    x,
+                    dtype=torch.int32,
+                )
+
+                # Step 2: Quantize using the calculated parameters per token
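+                # Same affine formula as per-tensor, with one (scale, zero_point)
+                # pair broadcast per token.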
+                quantized = torch.ops.quantized_decomposed.quantize_per_token.default(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float per token
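+                # (quantized - zero_point) * scale applied token-wise;
+                # output_dtype selects the floating-point result type.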
+                dequantized = torch.ops.quantized_decomposed.dequantize_per_token.default(
+                    quantized,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                    output_dtype=torch.float32,
+                )
+
+                return dequantized
+
+        full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )