Commit a50dc92
Author: morelos
[ET-VK][testing] Q/DQ/CQP op comprehensive delegate dynamic quantization testing

Pull Request resolved: #12210

Context: We need to ensure that most of the operators created so far work in tandem for dynamic quantization.

Changes: This adds two test cases covering the per_tensor and per_token pipelines, to verify that the full quantization workflow works as intended.

ghstack-source-id: 295489091
Exported using ghexport.
Differential Revision: [D77746139](https://our.internmc.facebook.com/intern/diff/D77746139/)
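
For reference, the round trip these tests exercise reduces to standard affine quantization. A minimal eager-mode sketch, assuming the usual clamp/round formulation of the quantized_decomposed reference ops; the scale and zero_point stand-ins below are illustrative, whereas the tests derive them via choose_qparams:

import torch

def quantize(x, scale, zero_point, qmin, qmax, dtype=torch.int32):
    # q = clamp(round(x / scale) + zero_point, qmin, qmax)
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(dtype)

def dequantize(q, scale, zero_point):
    # x_hat = (q - zero_point) * scale
    return (q.to(torch.float32) - zero_point) * scale

x = torch.rand(2, 3, 4)
scale, zero_point = 1e-5, 0  # stand-in values, not computed qparams
x_hat = dequantize(
    quantize(x, scale, zero_point, -2147483648, 2147483647), scale, zero_point
)
assert torch.allclose(x, x_hat, atol=5e-3, rtol=5e-3)

With the int32 range and a small scale, the round-trip error is bounded by scale / 2, which is why the tests below pass with a 5e-3 tolerance.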
Parent: 0800116

1 file changed: backends/vulkan/test/test_vulkan_delegate.py (+99, -0 lines)
@@ -46,6 +46,9 @@ def lower_module(
         if dynamic_shapes is not None:
             compile_options["require_dynamic_shapes"] = True

+        # Enable downcast_64_bit by default to handle float64/int64 tensors
+        compile_options["downcast_64_bit"] = True
+
         edge_compile_config = EdgeCompileConfig(
             _skip_dim_order=False,  # TODO(T182928844): Delegate dim order op to backend.
         )
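
A note on the downcast_64_bit default above: the eager choose_qparams reference op returns a float64 scale and an int64 zero_point, and Vulkan has no 64-bit tensor support, so the delegate needs these downcast to 32-bit. A quick sketch to confirm the dtypes, assuming the quantized_decomposed ops are registered by importing the torch.ao decomposed-op module:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- registers quantized_decomposed ops

x = torch.rand(2, 3, 4)
scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
    x, quant_min=-2147483648, quant_max=2147483647, eps=1e-5, dtype=torch.int32
)
print(scale.dtype, zero_point.dtype)  # expected: torch.float64 torch.int64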
@@ -1964,3 +1967,99 @@ def forward(self, x):
             GroupNormModule(num_groups, num_channels),
             sample_inputs,
         )
+
+    def test_vulkan_backend_full_quantization_workflow(self):
+        class FullQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per tensor
+                scale, zero_point = (
+                    torch.ops.quantized_decomposed.choose_qparams.tensor(
+                        x,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        eps=1e-5,
+                        dtype=torch.int32,
+                    )
+                )
+
+                # Step 2: Quantize using the calculated parameters
+                quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float
+                dequantized = (
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
+                        quantized,
+                        scale,
+                        zero_point,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        dtype=torch.int32,
+                    )
+                )
+
+                return dequantized
+
+        full_workflow_module = FullQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )
+
+    def test_vulkan_backend_full_per_token_quantization_workflow(self):
+        class FullPerTokenQuantizationWorkflowModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # Step 1: Choose quantization parameters per token
+                scale, zero_point = (
+                    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+                        x,
+                        dtype=torch.int32,
+                    )
+                )
+
+                # Step 2: Quantize using the calculated parameters per token
+                quantized = torch.ops.quantized_decomposed.quantize_per_token.default(
+                    x,
+                    scale,
+                    zero_point,
+                    quant_min=-2147483648,  # int32 min
+                    quant_max=2147483647,  # int32 max
+                    dtype=torch.int32,
+                )
+
+                # Step 3: Dequantize back to float per token
+                dequantized = (
+                    torch.ops.quantized_decomposed.dequantize_per_token.default(
+                        quantized,
+                        scale,
+                        zero_point,
+                        quant_min=-2147483648,  # int32 min
+                        quant_max=2147483647,  # int32 max
+                        dtype=torch.int32,
+                        output_dtype=torch.float32,
+                    )
+                )
+
+                return dequantized
+
+        full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule()
+        sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),)
+
+        # Use higher tolerance since quantization introduces some error
+        self.lower_module_and_test_output(
+            full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
+        )
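
For context on the per-token variant: with the (6, 4) input above, each of the 6 rows (tokens) gets its own asymmetric scale/zero_point computed from that row's min/max. A rough eager-mode sketch of the idea; the function name and exact clamping details are illustrative, not the delegate's kernels:

import torch

def choose_qparams_per_token(x, qmin=-2147483648, qmax=2147483647, eps=1e-5):
    # One (scale, zero_point) pair per token: reduce over the last dimension.
    mins = x.amin(dim=-1, keepdim=True).clamp(max=0.0)
    maxs = x.amax(dim=-1, keepdim=True).clamp(min=0.0)
    scale = ((maxs - mins) / (qmax - qmin)).clamp(min=eps)
    zero_point = (qmin - torch.round(mins / scale)).clamp(qmin, qmax)
    return scale, zero_point

scale, zero_point = choose_qparams_per_token(torch.rand(6, 4))
print(scale.shape, zero_point.shape)  # torch.Size([6, 1]) torch.Size([6, 1])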
