@@ -46,6 +46,9 @@ def lower_module(
     if dynamic_shapes is not None:
         compile_options["require_dynamic_shapes"] = True
 
+    # Enable downcast_64_bit by default to handle float64/int64 tensors
+    compile_options["downcast_64_bit"] = True
+
     edge_compile_config = EdgeCompileConfig(
         _skip_dim_order=False,  # TODO(T182928844): Delegate dim order op to backend.
     )
@@ -1964,3 +1967,99 @@ def forward(self, x):
             GroupNormModule(num_groups, num_channels),
             sample_inputs,
         )
+
+    def test_vulkan_backend_full_quantization_workflow(self):
+        class FullQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per tensor
+                scale, zero_point = (
+                    torch.ops.quantized_decomposed.choose_qparams.tensor(
+                        x,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        eps=1e-5,
+                        dtype=torch.int32,
+                    )
+                )
+
+                # Step 2: Quantize using the calculated parameters
+                quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float
+                dequantized = (
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
+                        quantized,
+                        scale,
+                        zero_point,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        dtype=torch.int32,
+                    )
+                )
+
+                return dequantized
+
+        full_workflow_module = FullQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )
+
+    def test_vulkan_backend_full_per_token_quantization_workflow(self):
+        class FullPerTokenQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per token
+                scale, zero_point = (
+                    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+                        x,
+                        dtype=torch.int32,
+                    )
+                )
+
+                # Step 2: Quantize using the calculated parameters per token
+                quantized = torch.ops.quantized_decomposed.quantize_per_token.default(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float per token
+                dequantized = (
+                    torch.ops.quantized_decomposed.dequantize_per_token.default(
+                        quantized,
+                        scale,
+                        zero_point,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        dtype=torch.int32,
+                        output_dtype=torch.float32,
+                    )
+                )
+
+                return dequantized
+
+        full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )
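For reference, both round trips can be exercised in eager mode as a quick sanity check of the op signatures used above before lowering to Vulkan. A minimal sketch, assuming the `quantized_decomposed` ops are registered (importing `torch.ao.quantization.fx._decomposed` does this in recent PyTorch builds):

```python
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (assumed to register quantized_decomposed ops)

# Per-tensor round trip, mirroring FullQuantizationWorkflowModule
x = torch.rand(2, 3, 4, dtype=torch.float32)
scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
    x, quant_min=-2147483648, quant_max=2147483647, eps=1e-5, dtype=torch.int32
)
q = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
    x, scale, zero_point, quant_min=-2147483648, quant_max=2147483647, dtype=torch.int32
)
x_hat = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
    q, scale, zero_point, quant_min=-2147483648, quant_max=2147483647, dtype=torch.int32
)
assert torch.allclose(x, x_hat, atol=5e-3, rtol=5e-3)

# Per-token round trip, mirroring FullPerTokenQuantizationWorkflowModule:
# each row of the (6, 4) input gets its own scale/zero_point.
y = torch.rand(6, 4, dtype=torch.float32)
scale, zero_point = (
    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
        y, dtype=torch.int32
    )
)
q = torch.ops.quantized_decomposed.quantize_per_token.default(
    y, scale, zero_point, quant_min=-2147483648, quant_max=2147483647, dtype=torch.int32
)
y_hat = torch.ops.quantized_decomposed.dequantize_per_token.default(
    q, scale, zero_point, quant_min=-2147483648, quant_max=2147483647,
    dtype=torch.int32, output_dtype=torch.float32,
)
assert torch.allclose(y, y_hat, atol=5e-3, rtol=5e-3)
```

The `atol`/`rtol` of 5e-3 matches the tests' tolerance; with int32's very wide quantization range, the reconstruction error for inputs in [0, 1) is expected to be far below it.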