Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from model_compression_toolkit.logger import Logger

from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR


@dataclass
class ManualBitWidthSelection:
Expand Down Expand Up @@ -221,9 +223,10 @@ def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict:
if isinstance(attr_str, str) and isinstance(manual_bit_width_selection.attr, str):
if attr_str.find(manual_bit_width_selection.attr) != -1:
attr.append(attr_str)
elif isinstance(attr_str, int) and isinstance(manual_bit_width_selection.attr, int):
if attr_str == manual_bit_width_selection.attr:
attr.append(attr_str)
# this is a positional attribute, so it needs to be handled separately.
# Search manual_bit_width_selection's attribute that contain the POS_ATTR string.
elif isinstance(attr_str, int) and POS_ATTR in manual_bit_width_selection.attr:
attr.append(POS_ATTR)
if len(attr) == 0:
Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')

Expand Down
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add fail tests: try to manually select an unsupported bitwidth.

Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
import model_compression_toolkit as mct
import torch
from torch.nn import Conv2d
from model_compression_toolkit.target_platform_capabilities.constants import BIAS, PYTORCH_KERNEL
from torch import add, sub

from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
from model_compression_toolkit.target_platform_capabilities.constants import BIAS, PYTORCH_KERNEL, POS_ATTR
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS
from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter
from model_compression_toolkit.core import CoreConfig
Expand Down Expand Up @@ -94,52 +97,142 @@ def generate_tpc_local(default_config, base_config, mixed_precision_cfg_list):
return generated_tpc


def get_tpc(kernel_n_bits, bias_n_bits):
    """Build a TPC whose kernel/bias weight attributes use the given bit widths."""
    base_cfg, mixed_precision_cfgs, default_cfg = get_op_qco(kernel_n_bits, bias_n_bits)
    return generate_tpc_local(default_cfg, base_cfg, mixed_precision_cfgs)
def generate_tpc_pos_attr_local(default_config):
    """Build a TargetPlatformCapabilities with positional-weight (POS_ATTR) options.

    The ADD operator set offers 8/16-bit positional-weight configurations and the
    SUB operator set offers a 2-bit one, all derived from ``default_config``.

    Args:
        default_config: Base op quantization config used both as the TPC default
            and as the source for the per-operator clones below.

    Returns:
        schema.TargetPlatformCapabilities containing the ADD and SUB operator sets.
    """
    default_configuration_options = schema.QuantizationConfigOptions(
        quantization_configurations=(default_config,))

    const_config_input16 = default_config.clone_and_edit(
        supported_input_activation_n_bits=(8, 16))

    def _pos_attr_weight_cfg(n_bits):
        # Quantization config for a positional weight attribute at `n_bits` bits.
        return schema.AttributeQuantizationConfig(
            weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
            weights_n_bits=n_bits,
            weights_per_channel_threshold=False,
            enable_weights_quantization=True,
            lut_values_bitwidth=None)

    def _with_pos_attr(n_bits):
        # Clone of the 16-bit-input config carrying a POS_ATTR weight config.
        return const_config_input16.clone_and_edit(
            attr_weights_configs_mapping={POS_ATTR: _pos_attr_weight_cfg(n_bits)})

    # ADD: positional weights may be quantized to 8 or 16 bits.
    const_configuration_options_inout16 = schema.QuantizationConfigOptions(
        quantization_configurations=(const_config_input16,
                                     _with_pos_attr(8),
                                     _with_pos_attr(16)),
        base_config=const_config_input16)

    # SUB: positional weights may be quantized to 2 bits.
    const_configuration_options_inout_2 = schema.QuantizationConfigOptions(
        quantization_configurations=(const_config_input16,
                                     _with_pos_attr(2)),
        base_config=const_config_input16)

    # Named add_op/sub_op so we do not shadow the module-level torch `add`/`sub`
    # imports used by the test classes below.
    add_op = schema.OperatorsSet(name=schema.OperatorSetNames.ADD,
                                 qc_options=const_configuration_options_inout16)
    sub_op = schema.OperatorsSet(name=schema.OperatorSetNames.SUB,
                                 qc_options=const_configuration_options_inout_2)

    return schema.TargetPlatformCapabilities(
        default_qco=default_configuration_options,
        operator_set=(add_op, sub_op))


def representative_data_gen(shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=1):
    """Yield `num_iter` batches; each batch is the same random tensor repeated `num_inputs` times."""
    for _ in range(num_iter):
        batch = torch.randn(batch_size, *shape)
        yield [batch for _ in range(num_inputs)]


def get_float_model():
    """Return a small float Conv-Conv-ReLU model for the bit-width selection tests.

    Fix: the scraped diff left the class header/constructor duplicated (old and
    new hunk lines merged); this is the single intended definition.
    """
    class BaseModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
            self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.relu(x)
            return x

    return BaseModel()


def get_float_model_with_constants():
    """Return a float model whose add/sub ops take constant tensors as operands.

    After conversion the constants become positional weight attributes, which is
    what the POS_ATTR manual bit-width tests exercise.

    Fix: the scraped diff merged a stale conv-based ``forward`` and a duplicate
    ``return`` into this function; this is the single intended definition.
    """
    class BaseModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            a = torch.rand(8)
            b = torch.rand(8)
            # Moved to the framework device/dtype so conversion treats them as constants.
            self.a = to_torch_tensor(a)
            self.b = to_torch_tensor(b)

        def forward(self, x):
            x = torch.add(x, self.a)
            x = torch.sub(self.b, x)
            return x

    return BaseModel()


class TestManualWeightsBitwidthSelectionByLayerType:
    def get_float_model(self):
        # Hook for subclasses to substitute a different float model.
        return get_float_model()

    def get_tpc(self, kernel_n_bits, bias_n_bits):
        # Hook for subclasses to substitute a different TPC; this default builds
        # one quantizing kernel/bias attributes to the requested bit widths.
        base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
        tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list)
        return tpc

# (LayerType, bit width, attribute, kernel_n_bits, bias_n_bits)
test_input_1 = (NodeTypeFilter(Conv2d), 16, PYTORCH_KERNEL, 16, None)
test_input_2 = (NodeTypeFilter(Conv2d), [2], [PYTORCH_KERNEL], 2, None)

test_expected_1 = ([Conv2d], [16])
test_expected_2 = ([Conv2d], [2])

@pytest.mark.parametrize(("inputs", "expected"), [
(test_input_1, test_expected_1),
(test_input_2, test_expected_2),
])

def test_manual_weights_bitwidth_selection(self, inputs, expected):
float_model = get_float_model()
float_model = self.get_float_model()

target_platform_cap = self.get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4])

target_platform_cap = get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4])

core_config = CoreConfig()
core_config.bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2])

quantized_model, _ = mct.ptq.pytorch_post_training_quantization(
in_module=float_model,
representative_data_gen=representative_data_gen,
Expand All @@ -157,12 +250,20 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected):
attrs = [attrs]

for bitwidth, attr in zip(expected_bitwidths, attrs):

if layer.weights_quantizers.get(attr) is not None:
assert layer.weights_quantizers.get(attr).num_bits == bitwidth


class TestManualWeightsBitwidthSelectionByLayerName:
    def get_float_model(self):
        # Hook for subclasses to substitute a different float model.
        return get_float_model()

    def get_tpc(self, kernel_n_bits, bias_n_bits):
        # Hook for subclasses to substitute a different TPC; this default builds
        # one quantizing kernel/bias attributes to the requested bit widths.
        base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
        tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list)
        return tpc

# (LayerName, bit width, attribute, kernel_n_bits, bias_n_bits)
test_input_1 = (NodeNameFilter("conv1"), 16, PYTORCH_KERNEL, 16, None)
test_input_2 = (NodeNameFilter("conv1"), [2], [PYTORCH_KERNEL], 2, None)
Expand All @@ -171,22 +272,21 @@ class TestManualWeightsBitwidthSelectionByLayerName:
test_expected_1 = (["conv1"], [16])
test_expected_2 = (["conv1"], [2])
test_expected_3 = (["conv1", "conv1"], [4, 16])

@pytest.mark.parametrize(("inputs", "expected"), [
(test_input_1, test_expected_1),
(test_input_2, test_expected_2),
(test_input_3, test_expected_3),
])

def test_manual_weights_bitwidth_selection(self, inputs, expected):

float_model = get_float_model()
float_model = self.get_float_model()

target_platform_cap = self.get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4])

target_platform_cap = get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4])

core_config = CoreConfig()
core_config.bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2])

quantized_model, _ = mct.ptq.pytorch_post_training_quantization(
in_module=float_model,
representative_data_gen=representative_data_gen,
Expand All @@ -207,7 +307,54 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected):
else:
for attr in attrs:
if layer.weights_quantizers.get(attr) is not None:
if attr == PYTORCH_KERNEL:
assert layer.weights_quantizers.get(attr).num_bits == kernel_weights_n_bits
elif attr == BIAS:
assert layer.weights_quantizers.get(attr).num_bits == bias_weights_n_bits
if attr == PYTORCH_KERNEL:
assert layer.weights_quantizers.get(attr).num_bits == kernel_weights_n_bits
elif attr == BIAS:
assert layer.weights_quantizers.get(attr).num_bits == bias_weights_n_bits


class TestManualPositionalAttrWeightsBitwidthSelectionByLayerType(TestManualWeightsBitwidthSelectionByLayerType):
    """Manual bit-width selection for positional (constant) weight attributes,
    filtering nodes by layer type (torch add/sub)."""

    def get_float_model(self):
        # Model whose add/sub ops carry constant operands (positional weights).
        return get_float_model_with_constants()

    def get_tpc(self, kernel_n_bits, bias_n_bits):
        # Positional-attribute TPC; kernel/bias bit widths only shape the default config.
        _, _, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
        tpc = generate_tpc_pos_attr_local(default_config)
        return tpc

    # (LayerType filter, bit width, attribute, kernel_n_bits, bias_n_bits)
    test_input_1 = (NodeTypeFilter(add), 16, POS_ATTR, 8, 8)
    test_input_2 = (NodeTypeFilter(sub), [2], [POS_ATTR], 8, 8)

    test_expected_1 = ([add], [16])
    test_expected_2 = ([sub], [2])

    @pytest.mark.parametrize(("inputs", "expected"), [
        (test_input_1, test_expected_1),
        (test_input_2, test_expected_2),
    ])
    def test_manual_weights_bitwidth_selection(self, inputs, expected):
        # Re-declared only to re-parametrize the inherited test with these cases.
        super().test_manual_weights_bitwidth_selection(inputs, expected)

class TestManualPositionalAttrWeightsBitwidthSelectionByLayerName(TestManualWeightsBitwidthSelectionByLayerName):
    """Manual bit-width selection for positional (constant) weight attributes,
    filtering nodes by layer name ('add' / 'sub')."""

    def get_float_model(self):
        # Model whose add/sub ops carry constant operands (positional weights).
        return get_float_model_with_constants()

    def get_tpc(self, kernel_n_bits, bias_n_bits):
        # Positional-attribute TPC; kernel/bias bit widths only shape the default config.
        _, _, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
        tpc = generate_tpc_pos_attr_local(default_config)
        return tpc

    # (LayerName filter, bit width, attribute, kernel_n_bits, bias_n_bits)
    # Fix: the first case requested 8 bits while expecting 16 — request 16 so the
    # input and expectation agree, matching the ByLayerType variant above.
    test_input_1 = (NodeNameFilter("add"), 16, POS_ATTR, 8, 8)
    test_input_2 = (NodeNameFilter("sub"), [2], [POS_ATTR], 8, 8)

    test_expected_1 = (['add'], [16])
    test_expected_2 = (['sub'], [2])

    @pytest.mark.parametrize(("inputs", "expected"), [
        (test_input_1, test_expected_1),
        (test_input_2, test_expected_2),
    ])
    def test_manual_weights_bitwidth_selection(self, inputs, expected):
        # Re-declared only to re-parametrize the inherited test with these cases.
        super().test_manual_weights_bitwidth_selection(inputs, expected)
Loading