@@ -1964,3 +1964,99 @@ def forward(self, x):
             GroupNormModule(num_groups, num_channels),
             sample_inputs,
         )
+
+    def test_vulkan_backend_full_quantization_workflow(self):
+        class FullQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per tensor
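+                # choose_qparams inspects the min/max of x and derives a scale
+                # and zero_point that map that range onto [quant_min, quant_max]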
+                scale, zero_point = (
+                    torch.ops.quantized_decomposed.choose_qparams.tensor(
+                        x,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        eps=1e-5,
+                        dtype=torch.int32,
+                    )
+                )
+
+                # Step 2: Quantize using the calculated parameters
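+                # quantize_per_tensor computes round(x / scale) + zero_point,
+                # clamped to [quant_min, quant_max]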
+                quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float
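+                # dequantize_per_tensor computes (quantized - zero_point) * scale,
+                # so the round trip should only lose rounding error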
+                dequantized = (
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
+                        quantized,
+                        scale,
+                        zero_point,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        dtype=torch.int32,
+                    )
+                )
+
+                return dequantized
+
+        full_workflow_module = FullQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )
+
+    def test_vulkan_backend_full_per_token_quantization_workflow(self):
+        class FullPerTokenQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per token
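+                # The asymmetric per-token variant derives a separate
+                # (scale, zero_point) pair for each token, i.e. for each
+                # slice of x taken over its last dimension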
+                scale, zero_point = (
+                    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+                        x,
+                        dtype=torch.int32,
+                    )
+                )
+
+                # Step 2: Quantize using the calculated parameters per token
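+                # scale and zero_point hold one value per token and broadcast
+                # across the last dimension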
+                quantized = torch.ops.quantized_decomposed.quantize_per_token.default(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float per token
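+                # output_dtype selects the floating-point dtype of the result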
+                dequantized = (
+                    torch.ops.quantized_decomposed.dequantize_per_token.default(
+                        quantized,
+                        scale,
+                        zero_point,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        dtype=torch.int32,
+                        output_dtype=torch.float32,
+                    )
+                )
+
+                return dequantized
+
+        full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )