diff --git a/tests/mock_observer.py b/tests/mock_observer.py
new file mode 100644
index 00000000..4563061c
--- /dev/null
+++ b/tests/mock_observer.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+from weakref import ref
+
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization.utils import (
+    calculate_qparams,
+    generate_gparam,
+    strategy_cdiv,
+)
+
+
+class MockMinMaxObserver(torch.nn.Module):
+    def __init__(self, base_name: str, args: QuantizationArgs, module: torch.nn.Module):
+        super().__init__()
+        self.parent = ref(module)
+        self.base_name = base_name
+        self.args = args
+
+        # used for testing
+        self.min_vals = None
+        self.max_vals = None
+
+    def get_min_max(self, observed: torch.Tensor):
+        min_vals = torch.amin(observed, dim=(0, -1))
+        max_vals = torch.amax(observed, dim=(0, -1))
+
+        return min_vals, max_vals
+
+    def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        observed = flatten_for_quantization(observed, self.base_name, self.args)
+
+        self.min_vals, self.max_vals = self.get_min_max(observed)
+
+        scales, zero_points = calculate_qparams(
+            min_vals=self.min_vals,
+            max_vals=self.max_vals,
+            quantization_args=self.args,
+            global_scale=getattr(self.parent(), f"{self.base_name}_global_scale", None),
+        )
+
+        return scales, zero_points
+
+    def get_global_scale(self, observed: torch.Tensor):
+        observed = observed.reshape((1, 1, -1))  # per tensor reshape
+        min_vals, max_vals = self.get_min_max(observed)
+        global_scale = generate_gparam(min_vals, max_vals)
+
+        return global_scale
+
+
+def flatten_for_quantization(
+    value: torch.Tensor, base_name: str, args: QuantizationArgs
+) -> torch.Tensor:
+    if base_name == "weight":
+        return flatten_weight_for_quantization(value, args)
+    elif base_name in ("input", "output"):
+        return flatten_activation_for_quantization(value, args)
+    elif base_name in ("q", "k", "v"):
+        return flatten_attention_for_quantization(value, args)
+    else:
+        raise ValueError(f"Unknown quantization base name: {base_name}")
+
+
+def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    if args.strategy == QuantizationStrategy.TENSOR:
+        # (1, 1, num_weight_elems)
+        return value.reshape((1, 1, -1))
+
+    if args.strategy == QuantizationStrategy.TOKEN:
+        raise ValueError("Token quantization cannot be applied to weights")
+
+    if args.strategy == QuantizationStrategy.CHANNEL:
+        # (1, num_rows, 1, num_cols)
+        return value.unsqueeze(-2).unsqueeze(0)
+
+    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        # (1, num_rows, num_groups, group_size)
+        return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0)
+
+    if args.strategy == QuantizationStrategy.BLOCK:
+        # (1, num_block_rows, num_block_cols, block_width * block_height)
+        block_height, block_width = args.block_structure
+        num_rows, num_cols = value.shape
+        num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy)
+        num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy)
+        return (
+            value.reshape(
+                num_block_rows,
+                block_height,
+                num_block_cols,
+                block_width,
+            )
+            .transpose(1, 2)
+            .flatten(-2, -1)
+            .unsqueeze(0)
+        )
+
+    assert False, f"Unknown strategy {args.strategy}"
+
+
+def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    if args.strategy == QuantizationStrategy.TENSOR:
+        # (batch_size * seq_len, 1, hidden_dim)
+        return value.reshape((-1, 1, value.size(-1)))
+
+    if args.strategy == QuantizationStrategy.TOKEN:
+        # (batch_size, seq_len, hidden_dim)
+        # warning: token quantization uses `compute_dynamic_scales_and_zp`
+        return value.flatten(2, -1)
+
+    if args.strategy == QuantizationStrategy.CHANNEL:
+        raise ValueError("Channel quantization cannot be applied to activations")
+
+    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        # (batch_size * seq_len, num_groups, group_size)
+        # warning: group activation quantization uses compute_dynamic_scales_and_zp
+        return value.flatten(0, 1).unflatten(-1, (-1, args.group_size))
+
+    if args.strategy == QuantizationStrategy.BLOCK:
+        raise ValueError("Block quantization cannot be applied to activations")
+
+    assert False, f"Unknown strategy {args.strategy}"
+
+
+def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    if args.strategy == QuantizationStrategy.TENSOR:
+        # (batch_size, seq_len, num_heads, head_dim)
+        # (batch_size * seq_len, 1, num_heads * head_dim)
+        return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
+
+    if args.strategy == QuantizationStrategy.TOKEN:
+        raise ValueError("Token quantization cannot be applied to attention")
+
+    if args.strategy == QuantizationStrategy.CHANNEL:
+        raise ValueError("Channel quantization cannot be applied to attention")
+
+    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        raise ValueError("Group quantization cannot be applied to attention")
+
+    if args.strategy == QuantizationStrategy.BLOCK:
+        raise ValueError("Block quantization cannot be applied to attention")
+
+    assert False, f"Unknown strategy {args.strategy}"
diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py b/tests/test_quantization/lifecycle/test_static_lifecycle.py
new file mode 100644
index 00000000..45ba602c
--- /dev/null
+++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py
@@ -0,0 +1,351 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    forward_quantize,
+    initialize_module_for_quantization,
+    initialize_qparams,
+)
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from tests.mock_observer import MockMinMaxObserver
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_quant,exp_loss",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",  # equivalent to token
+            ),
+            torch.tensor([0.0]),
+            torch.tensor([23.0]),
+            torch.tensor(
+                [
+                    [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250],
+                    [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500],
+                    [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750],
+                    [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
+                ],
+                dtype=torch.bfloat16,
+            ),
+            0.85,
+        ),
+        # token is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="channel",
+            ),
+            torch.tensor([[0], [6], [12], [18]]),
+            torch.tensor([[5], [11], [17], [23]]),
+            torch.tensor(
+                [
+                    [0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875],
+                    [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500],
+                    [11.3125, 13.6250, 13.6250, 15.8750, 15.8750, 15.8750],
+                    [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
+                ],
+                dtype=torch.bfloat16,
+            ),
+            0.45,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="group",
+                group_size=3,
+            ),
+            torch.tensor([[0, 3], [6, 9], [12, 15], [18, 21]]),
+            torch.tensor([[2, 5], [8, 11], [14, 17], [20, 23]]),
+            torch.tensor(
+                [
+                    [0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875],
+                    [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500],
+                    [11.1875, 13.0625, 13.0625, 15.8750, 15.8750, 15.8750],
+                    [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000],
+                ],
+            ),
+            0.45,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",  # tensor group requires FP4
+                symmetric=True,
+                strategy="tensor_group",  # requires float4
+                group_size=3,
+            ),
+            torch.tensor([[0, 3], [6, 9], [12, 15], [18, 21]]),
+            torch.tensor([[2, 5], [8, 11], [14, 17], [20, 23]]),
+            torch.tensor(
+                [
+                    [0.0000, 1.0234, 2.0469, 3.2812, 3.2812, 4.9375],
+                    [5.4688, 8.1875, 8.1875, 10.6875, 10.6875, 10.6875],
+                    [9.8750, 14.7500, 14.7500, 16.3750, 16.3750, 16.3750],
+                    [19.7500, 19.7500, 19.7500, 23.0000, 23.0000, 23.0000],
+                ],
+            ),
+            1.1,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="block",
+                block_structure=[2, 3],
+            ),
+            torch.tensor([[0, 3], [12, 15]]),
+            torch.tensor([[8, 11], [20, 23]]),
+            torch.tensor(
+                [
+                    [0.0000, 1.0703, 2.1406, 2.9375, 4.4062, 4.4062],
+                    [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500],
+                    [10.6875, 13.3750, 13.3750, 15.3125, 15.3125, 18.3750],
+                    [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000],
+                ],
+            ),
+            0.5,
+        ),
+    ],
+)
+def test_static_weight_quantization(
+    args, exp_min_val, exp_max_val, exp_quant, exp_loss
+):
+    """
+    weight = tensor([[ 0,  1,  2,  3,  4,  5],
+                     [ 6,  7,  8,  9, 10, 11],
+                     [12, 13, 14, 15, 16, 17],
+                     [18, 19, 20, 21, 22, 23]])
+    """
+    # set up weight
+    input_size, output_size = 6, 4
+    linear = torch.nn.Linear(input_size, output_size, bias=False)
+    linear.weight.data = torch.arange(
+        input_size * output_size, dtype=torch.bfloat16
+    ).reshape(output_size, input_size)
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], weights=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+    linear.weight_observer = MockMinMaxObserver("weight", args, linear)
+
+    # calibrate global scale
+    if hasattr(linear, "weight_global_scale"):
+        global_scale = linear.weight_observer.get_global_scale(linear.weight)
+        linear.weight_global_scale.data = global_scale
+
+    # calibrate quantization parameters
+    scale, zero_point = linear.weight_observer(linear.weight)
+    linear.weight_scale.data = scale
+    linear.weight_zero_point.data = zero_point
+    assert torch.equal(linear.weight_observer.min_vals, exp_min_val)
+    assert torch.equal(linear.weight_observer.max_vals, exp_max_val)
+
+    # forward pass
+    input = torch.eye(input_size, dtype=torch.bfloat16)
+    output = linear(input)
+
+    assert torch.allclose(output.T, exp_quant.to(output.dtype))
+    assert torch.nn.functional.mse_loss(output.T, linear.weight) <= exp_loss
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_quant,exp_loss",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",
+            ),
+            torch.tensor([0.0]),
+            torch.tensor([11.0]),
+            torch.tensor(
+                [
+                    [
+                        [0.0000, 1.4688, 1.4688, 2.9375, 4.4062, 4.4062],
+                        [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500],
+                    ]
+                ]
+            ),
+            0.2,
+        ),
+        # static token is not supported
+        # channel is not supported
+        # group is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",  # must be fp4
+                symmetric=True,
+                strategy="tensor_group",
+                dynamic="local",
+                group_size=3,
+            ),
+            None,
+            None,
+            torch.tensor(
+                [
+                    [
+                        [0.0000, 0.9844, 1.9688, 3.4062, 3.4062, 5.1250],
+                        [5.2500, 7.8750, 7.8750, 7.3438, 11.0000, 11.0000],
+                    ]
+                ]
+            ),
+            0.5,
+        ),
+        # block is not supported
+        # head is not supported
+    ],
+)
+def test_static_activation_quantization(
+    args, exp_min_val, exp_max_val, exp_quant, exp_loss
+):
+    """
+    input = tensor([[ 0,  1,  2,  3,  4,  5],
+                    [ 6,  7,  8,  9, 10, 11]])
+    """
+    # set up activation (and identity weight)
+    batch_size, seq_len, input_size = 1, 2, 6
+    input = torch.arange(
+        (batch_size * seq_len * input_size), dtype=torch.bfloat16
+    ).reshape((batch_size, seq_len, input_size))
+    linear = torch.nn.Linear(input_size, input_size, bias=False)
+    linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16)
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+    linear.input_observer = MockMinMaxObserver("input", args, linear)
+
+    # calibrate quantization parameters
+    def calibrate_input_hook(_, args):
+        if hasattr(linear, "input_global_scale"):
+            global_scale = linear.input_observer.get_global_scale(args[0])
+            linear.input_global_scale.data = global_scale
+
+        if linear.quantization_scheme.input_activations.dynamic is False:
+            scale, zero_point = linear.input_observer(args[0])
+            linear.input_scale.data = scale
+            linear.input_zero_point.data = zero_point
+
+    linear.register_forward_pre_hook(calibrate_input_hook)
+
+    # calibration forward pass
+    output = linear(input)
+
+    # check calibration
+    if exp_min_val is not None:
+        assert torch.equal(linear.input_observer.min_vals, exp_min_val)
+    if exp_max_val is not None:
+        assert torch.equal(linear.input_observer.max_vals, exp_max_val)
+
+    # check forward pass
+    assert torch.allclose(output, exp_quant.to(output.dtype))
+    assert torch.nn.functional.mse_loss(output, input) <= exp_loss
+
+
+class MockAttention(torch.nn.Module):
+    pass
+
+
+@pytest.mark.filterwarnings("ignore::UserWarning")
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_quant,exp_loss",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",
+            ),
+            torch.tensor([0.0]),
+            torch.tensor([11.0]),
+            torch.tensor(
+                [
+                    [
+                        [[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]],
+                        [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]],
+                    ]
+                ]
+            ),
+            0.19,
+        ),
+        # static token is not supported
+        # channel is not supported
+        # group is not supported
+        # tensor group is not supported
+        # block is not supported
+    ],
+)
+def test_static_attention_quantization(
+    args, exp_min_val, exp_max_val, exp_quant, exp_loss
+):
+    """
+    input = tensor([[[[ 0.,  1.,  2.],
+                      [ 3.,  4.,  5.]],
+
+                     [[ 6.,  7.,  8.],
+                      [ 9., 10., 11.]]]])
+    """
+    # set up attention key states
+    batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3
+    input = torch.arange(
+        (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16
+    ).reshape((batch_size, seq_len, num_heads, head_dim))
+    attention = MockAttention()
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_qparams(
+        attention, "k", args, (num_heads, head_dim), observed_dtype=torch.bfloat16
+    )
+    attention.quantization_scheme = scheme
+    attention.quantization_status = QuantizationStatus.INITIALIZED
+    attention.k_observer = MockMinMaxObserver("k", args, attention)
+
+    # calibrate quantization parameters
+    if scheme.input_activations.dynamic is False:
+        scale, zero_point = attention.k_observer(input)
+        attention.k_scale.data = scale
+        attention.k_zero_point.data = zero_point
+
+    # calibration forward pass
+    output = forward_quantize(attention, input, "k", scheme.input_activations)
+
+    # check calibration
+    if exp_min_val is not None:
+        assert torch.equal(attention.k_observer.min_vals, exp_min_val)
+    if exp_max_val is not None:
+        assert torch.equal(attention.k_observer.max_vals, exp_max_val)
+
+    # check forward pass
+    assert torch.allclose(output, exp_quant.to(output.dtype))
+    assert torch.nn.functional.mse_loss(output, input) <= exp_loss