diff --git a/musco/pytorch/compressor/layers/cp3_conv1d.py b/musco/pytorch/compressor/layers/cp3_conv1d.py new file mode 100644 index 0000000..6761d07 --- /dev/null +++ b/musco/pytorch/compressor/layers/cp3_conv1d.py @@ -0,0 +1,186 @@ +import numpy as np +import torch +from torch import nn +from argparse import Namespace + +import tensorly as tl +from tensorly.decomposition import parafac +from tensorly.kruskal_tensor import kruskal_to_tensor + +from musco.pytorch.compressor.rank_estimation.estimator import estimate_rank_for_compression_rate +from .base import DecomposedLayer + +tl.set_backend('pytorch') + + +class CP3DecomposedLayerConv1D(nn.Module, DecomposedLayer): + """Convolutional 1D layer with a kernel (k spacial size, k>1) represented in CP3 format. + + References + ---------- + .. [1] Lebedev, Vadim, et al. (2014). "Speeding-up convolutional neural networks using fine-tuned cp-decomposition."Proceedings of the International Conference on Learning Representations. + """ + def __init__(self, + layer, + layer_name, + algo_kwargs={}, + **compr_kwargs): + + nn.Module.__init__(self) + DecomposedLayer.__init__(self, layer, layer_name, algo_kwargs=algo_kwargs) + + assert compr_kwargs['decomposition'] in ['cp3_conv1d', 'qcp3_conv1d'] + self.min_rank = 2 + + self.cin = None + self.cout = None + self.kernel_size = None + self.padding = None + self.stride = None + self.device = None + self.dilation = None + + # Initialize layer parameters + self.init_layer_params() + self.init_device() + + # Estimate rank for tensor approximation and build new layers + self.estimate_rank(**compr_kwargs) + self.build_new_layers() + + # Compute weights for new layers, initialize new layers + self.init_new_layers(*self.compute_new_weights(*self.extract_weights(), algo_kwargs)) + + self.layer = None + self.__delattr__('layer') + + + def init_layer_params(self): + if isinstance(self.layer, nn.Sequential): + self.cin = self.layer[0].in_channels + self.cout = self.layer[-1].out_channels + + self.kernel_size = self.layer[1].kernel_size + self.padding = self.layer[1].padding + self.stride = self.layer[1].stride + + self.dilation = self.layer.dilation + + + else: + if not isinstance(self.layer, nn.Conv1d): + raise AttributeError('Only convolutional layer can be decomposed') + self.cin = self.layer.in_channels + self.cout = self.layer.out_channels + + self.kernel_size = self.layer.kernel_size + self.padding = self.layer.padding + self.stride = self.layer.stride + + self.dilation = self.layer.dilation + + + def estimate_rank(self, **compr_kwargs): + compr_kwargs = Namespace(**compr_kwargs) + + if compr_kwargs.rank_selection == 'param_reduction': + if isinstance(self.layer, nn.Sequential): + prev_rank = self.layer[0].out_channels + else: + prev_rank = None + + tensor_shape = (self.cout, self.cin, *self.kernel_size) + self.rank = estimate_rank_for_compression_rate(tensor_shape, + rate = compr_kwargs.param_reduction_rate, + tensor_format = compr_kwargs.decomposition, + prev_rank = prev_rank, + min_rank = self.min_rank) + elif compr_kwargs.rank_selection == 'manual': + i = compr_kwargs.curr_compr_iter + self.rank = compr_kwargs.manual_rank[i] + + + def extract_weights(self): + if isinstance(self.layer, nn.Sequential): + w_cin = self.layer[0].weight.data + w_w = self.layer[1].weight.data + w_cout = self.layer[2].weight.data + + try: + bias = self.layer[-1].bias.data + except: + bias = None + + f_w = w_w.squeeze().t() + f_cin = w_cin.squeeze().t() + f_cout = w_cout.squeeze() + + weight = [f_cout, f_cin, f_w] + + else: + weight = self.layer.weight.data + try: + bias = self.layer.bias.data + except: + bias = None + return weight, bias + + + def compute_new_weights(self, weight, bias, algo_kwargs={}): + if 'qscheme' in algo_kwargs.keys(): + from tensorly.decomposition import quantized_parafac as parafac + print('Quantized') + else: + from tensorly.decomposition import parafac + + if isinstance(self.layer, nn.Sequential): + lmbda, (f_cout, f_cin, f_w) = parafac(kruskal_to_tensor((None, weight)), + self.rank, + **algo_kwargs) + else: + lmbda, (f_cout, f_cin, f_w) = parafac(weight, + self.rank, + **algo_kwargs) + + +# # Reshape factor matrices to 3D weight tensors +# f_cin: (cin, rank) -> (rank, cin, 1) +# f_w: (w, rank) -> (rank, 1, w) +# f_cout: (count, rank) -> (count, rank, 1) + + # Pytorh case + f_cin = (lmbda * f_cin).t().unsqueeze_(2).contiguous().to(self.device) + f_w = f_w.t().unsqueeze_(1).contiguous().to(self.device) + f_cout = f_cout.unsqueeze_(2).contiguous().to(self.device) + + return [f_cin, f_w, f_cout], [None, None, bias] + + + def build_new_layers(self): + layers = [] + layers.append(nn.Conv1d(in_channels=self.cin, + out_channels=self.rank, + kernel_size = 1)) + + layers.append(nn.Conv1d(in_channels = self.rank, + out_channels=self.rank, + kernel_size = self.kernel_size, + groups = self.rank, + padding = self.padding, + stride = self.stride, + dilation = self.dilation + )) + + layers.append(nn.Conv1d(in_channels = self.rank, + out_channels = self.cout, + kernel_size = 1)) + + self.new_layers = nn.Sequential() + + for j, l in enumerate(layers): + self.new_layers.add_module('{}-{}'.format(self.layer_name, j), l) + + + def forward(self, x): + x = self.new_layers(x) + return x \ No newline at end of file diff --git a/musco/pytorch/compressor/layers/utils.py b/musco/pytorch/compressor/layers/utils.py index 189da65..2af42b7 100644 --- a/musco/pytorch/compressor/layers/utils.py +++ b/musco/pytorch/compressor/layers/utils.py @@ -15,6 +15,13 @@ def get_all_algo_kwargs(): """ all_algo_kwargs = defaultdict(dict) + + all_algo_kwargs['cp3_conv1d'] = vars(Namespace(n_iter_max=5000, + init='random', + tol=1e-8, + svd = None, + cvg_criterion = 'rec_error', + normalize_factors = True)) all_algo_kwargs['cp3'] = vars(Namespace(n_iter_max=5000, init='random', @@ -24,18 +31,36 @@ def get_all_algo_kwargs(): normalize_factors = True)) all_algo_kwargs['cp4'] = vars(Namespace(n_iter_max=5000, - init='random', - tol=1e-8, - svd = None, - stop_criterion = 'rec_error_deviation', - normalize_factors = True)) - + init='random', + tol=1e-8, + svd = None, + stop_criterion = 'rec_error_deviation', + normalize_factors = True)) + all_algo_kwargs['tucker2'] = vars(Namespace(init='nvecs')) all_algo_kwargs['svd'] = vars(Namespace(full_matrices=False)) QSCHEME = torch.per_channel_symmetric DIM = 1 + + all_algo_kwargs['qcp3_conv1d'] = dict( + **vars(Namespace(n_iter_max=500, + init='random', + tol=1e-8, + svd=None, + normalize_factors=True, + )), + **vars(Namespace(dtype=torch.qint8, + qscheme=QSCHEME, + dim=DIM, + )), + **vars(Namespace(qmodes=[0, 1, 2], + return_scale_zeropoint=False, + stop_criterion='rec_error_deviation', + return_qerrors=False, + )) + ) all_algo_kwargs['qcp3'] = dict( **vars(Namespace(n_iter_max=500, diff --git a/musco/pytorch/compressor/utils.py b/musco/pytorch/compressor/utils.py index 90fba76..b821975 100755 --- a/musco/pytorch/compressor/utils.py +++ b/musco/pytorch/compressor/utils.py @@ -12,6 +12,7 @@ from .layers.cp4 import CP4DecomposedLayer from .layers.svd_layer import SVDDecomposedLayer, SVDDecomposedConvLayer from .layers.base import DecomposedLayer +from .layers.cp3_conv1d import CP3DecomposedLayerConv1D def get_compressed_model(model, @@ -87,6 +88,10 @@ def get_compressed_model(model, elif decomposition in ['cp4', 'qcp4']: decomposed_layer = CP4DecomposedLayer(layer, subm_names[-1], algo_kwargs, **compr_kwargs) + elif decomposition in ['cp3_conv1d', 'qcp3_conv1d']: + decomposed_layer = CP3DecomposedLayerConv1D(layer, subm_names[-1], algo_kwargs, **compr_kwargs) + + elif decomposition == 'svd': if layer_type == nn.Conv2d: decomposed_layer = SVDDecomposedConvLayer(layer, subm_names[-1], algo_kwargs, **compr_kwargs) @@ -113,6 +118,24 @@ def get_compressed_model(model, return compressed_model +# def standardize_model(model): +# """Replace custom layers with standard nn.Module layers. + +# Relplace each layer of type DecomposedLayer with nn.Sequential. +# """ + +# for mname, m in model.named_modules(): +# if isinstance(m, DecomposedLayer): +# subm_names = mname.strip().split('.') + +# if len(subm_names) > 1: +# m = model.__getattr__(subm_names[0]) +# for s in subm_names[1:-1]: +# m = m.__getattr__(s) +# m.__setattr__(subm_names[-1], m.__getattr__(subm_names[-1]).new_layers) +# else: +# model.__setattr__(subm_names[-1], m.new_layers) + def standardize_model(model): """Replace custom layers with standard nn.Module layers. @@ -127,6 +150,23 @@ def standardize_model(model): m = model.__getattr__(subm_names[0]) for s in subm_names[1:-1]: m = m.__getattr__(s) + + m_dict = m.__getattr__(subm_names[-1]).__dict__ + m_dict_new_layer = m.__getattr__(subm_names[-1]).new_layers.__dict__ + m.__setattr__(subm_names[-1], m.__getattr__(subm_names[-1]).new_layers) + + m = m.__getattr__(subm_names[-1]) + for field in set(m_dict.keys()) - set(m_dict_new_layer.keys()): + m.__setattr__(field, m_dict[field]) + else: - model.__setattr__(subm_names[-1], m.new_layers) \ No newline at end of file + + m_dict = model.__getattr__(subm_names[-1]).__dict__ + m_dict_new_layer = model.__getattr__(subm_names[-1]).new_layers.__dict__ + + model.__setattr__(subm_names[-1], m.new_layers) + + m = model.__getattr__(subm_names[-1]) + for field in set(m_dict.keys()) - set(m_dict_new_layer.keys()): + m.__setattr__(field, m_dict[field]) \ No newline at end of file