Skip to content

Commit e5605ee

Browse files
author
morelos
committed
[ET-VK][testing] Q/DQ/CQP op comprehensive delegate dynamic quantization testing
# Context We need to ensure that most of the operators that were created work in tandem with each other for dynamic quantization. # Changes This creates two test cases to test the per_token and per_tensor pipeline to ensure that the whole full quantization workflow works as intended. Differential Revision: [D77746139](https://our.internmc.facebook.com/intern/diff/D77746139/) [ghstack-poisoned]
1 parent 10512c5 commit e5605ee

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,3 +1964,87 @@ def forward(self, x):
19641964
GroupNormModule(num_groups, num_channels),
19651965
sample_inputs,
19661966
)
1967+
1968+
def test_vulkan_backend_full_quantization_workflow(self):
1969+
class FullQuantizationWorkflowModule(torch.nn.Module):
1970+
def __init__(self):
1971+
super().__init__()
1972+
1973+
def forward(self, x):
1974+
# Step 1: Choose quantization parameters per tensor
1975+
scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
1976+
x,
1977+
quant_min=-2147483648, # int32 min
1978+
quant_max=2147483647, # int32 max
1979+
eps=1e-5,
1980+
dtype=torch.int32,
1981+
)
1982+
1983+
# Step 2: Quantize using the calculated parameters
1984+
quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
1985+
x,
1986+
scale,
1987+
zero_point,
1988+
quant_min=-2147483648, # int32 min
1989+
quant_max=2147483647, # int32 max
1990+
dtype=torch.int32,
1991+
)
1992+
1993+
# Step 3: Dequantize back to float
1994+
dequantized = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
1995+
quantized,
1996+
scale,
1997+
zero_point,
1998+
quant_min=-2147483648, # int32 min
1999+
quant_max=2147483647, # int32 max
2000+
dtype=torch.int32,
2001+
)
2002+
2003+
return dequantized
2004+
2005+
full_workflow_module = FullQuantizationWorkflowModule()
2006+
sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),)
2007+
2008+
# Use higher tolerance since quantization introduces some error
2009+
self.lower_module_and_test_output(full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3)
2010+
2011+
def test_vulkan_backend_full_per_token_quantization_workflow(self):
2012+
class FullPerTokenQuantizationWorkflowModule(torch.nn.Module):
2013+
def __init__(self):
2014+
super().__init__()
2015+
2016+
def forward(self, x):
2017+
# Step 1: Choose quantization parameters per token
2018+
scale, zero_point = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
2019+
x,
2020+
dtype=torch.int32,
2021+
)
2022+
2023+
# Step 2: Quantize using the calculated parameters per token
2024+
quantized = torch.ops.quantized_decomposed.quantize_per_token.default(
2025+
x,
2026+
scale,
2027+
zero_point,
2028+
quant_min=-2147483648, # int32 min
2029+
quant_max=2147483647, # int32 max
2030+
dtype=torch.int32,
2031+
)
2032+
2033+
# Step 3: Dequantize back to float per token
2034+
dequantized = torch.ops.quantized_decomposed.dequantize_per_token.default(
2035+
quantized,
2036+
scale,
2037+
zero_point,
2038+
quant_min=-2147483648, # int32 min
2039+
quant_max=2147483647, # int32 max
2040+
dtype=torch.int32,
2041+
output_dtype=torch.float32,
2042+
)
2043+
2044+
return dequantized
2045+
2046+
full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule()
2047+
sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),)
2048+
2049+
# Use higher tolerance since quantization introduces some error
2050+
self.lower_module_and_test_output(full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3)

0 commit comments

Comments
 (0)