Training not converging on DenseNet-bottleneck #3

@lizhenstat

Hi, I want to test your method on DenseNet with the bottleneck structure (conv1x1 --> conv3x3) on CIFAR100.
I followed the code in densenet_svd.py and hinge_resnet_bottleneck.py,
mainly changing the following function:

def compress_module_param(module, percentage, threshold):
    # Bias is None, so things become easier.
    # get the body
    '''
    # transition 
    (0): BatchNorm2d(168, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU()
    (2): Conv2d(168, 84, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (3): Conv2d(84, 84, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    '''
    if isinstance(module, Transition):
        body = module

        conv1 = body._modules['2']
        conv2 = body._modules['3']

        ws1 = conv1.weight.data.shape
        weight1 = conv1.weight.data.squeeze().t()

        ws2 = conv2.weight.data.shape
        weight2 = conv2.weight.data.squeeze().t()

        # calculate pindex
        _, pindex = get_nonzero_index(weight1, dim='output', counter=1, percentage=percentage, threshold=threshold)

        pl = pindex.shape[0]
        weight1 = torch.index_select(weight1, dim=1, index=pindex) 
        conv1.weight = nn.Parameter(weight1.t().view(pl, ws1[1], ws1[2], ws1[3]))
        conv1.out_channels = pl

        # compress conv2
        conv2.weight = nn.Parameter(torch.index_select(weight2, dim=0, index=pindex).t().view(ws2[0], pl, ws2[2], ws2[3]))
        conv2.in_channels = pl

    elif isinstance(module, BottleNeck):
        # with torchsnooper.snoop():
        body = module._modules['body']
        # conv1x1
        conv1 = body._modules['2']
        batchnorm1 = body._modules['3'] # batchnorm applied to the output of conv1
        conv2 = body._modules['5']

        
        # get conv weights
        ws1 = conv1.weight.data.shape
        weight1 = conv1.weight.data.squeeze().t()
        
        bn_weight1 = batchnorm1.weight.data
        bn_bias1 = batchnorm1.bias.data
        bn_mean1 = batchnorm1.running_mean.data
        bn_var1 = batchnorm1.running_var.data
        
        ws2 = conv2.weight.data.shape
        weight2 = conv2.weight.data.view(ws2[0], ws2[1] * ws2[2] * ws2[3]).t()
        
        # select the channels to keep after compression
        _, pindex1 = get_nonzero_index(weight1, dim='output', counter=1, percentage=percentage, threshold=threshold)
        pl1 = len(pindex1)
        conv1.weight = nn.Parameter(torch.index_select(weight1, dim=1, index=pindex1).t().view(pl1, -1, 1, 1))
        conv1.out_channels = pl1

        # batchnorm1
        batchnorm1.weight = nn.Parameter(torch.index_select(bn_weight1, dim=0, index=pindex1)) 
        batchnorm1.bias = nn.Parameter(torch.index_select(bn_bias1, dim=0, index=pindex1))
        batchnorm1.running_mean = torch.index_select(bn_mean1, dim=0, index=pindex1)
        batchnorm1.running_var = torch.index_select(bn_var1, dim=0, index=pindex1)
        batchnorm1.num_features = pl1
        
        # conv2
        index = torch.repeat_interleave(pindex1, ws2[2] * ws2[3]) * ws2[2] * ws2[3] + \
                torch.tensor(range(0, ws2[2] * ws2[3])).repeat(pindex1.shape[0]).cuda()
        weight2 = torch.index_select(weight2, dim=0, index=index)
        # weight2 = torch.index_select(weight2, dim=1, index=pindex3)
        conv2.weight = nn.Parameter(weight2.view(ws2[0], pl1, 3, 3))
        conv2.in_channels = pl1
        # exit(0)
    else:
        raise NotImplementedError('Do not need to compress the layer ' + module.__class__.__name__)
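
One place that may be worth double-checking is the final reshape of conv2 in the BottleNeck branch. After `index_select`, `weight2` has shape `(pl1 * 9, ws2[0])`, so viewing it directly as `(ws2[0], pl1, 3, 3)` might not restore the original `(out, in, k, k)` layout. The sketch below repeats the same indexing steps on a toy tensor and compares both reshapes against plain slicing. The sizes and the `keep` indices are assumptions made up for illustration, not values from the actual model, and whether this explains the test curve is not confirmed.

```python
import torch

# Toy shapes (assumptions for illustration): 4 output filters, 3 input channels, 3x3 kernel.
out_c, in_c, k = 4, 3, 3
w = torch.arange(out_c * in_c * k * k, dtype=torch.float32).view(out_c, in_c, k, k)

keep = torch.tensor([0, 2])  # hypothetical indices of the input channels to keep
expected = w[:, keep]        # direct slicing serves as the reference result

# Same steps as the BottleNeck branch: flatten to (in_c*k*k, out_c), then pick rows.
w2d = w.view(out_c, in_c * k * k).t()
rows = torch.repeat_interleave(keep, k * k) * k * k + torch.arange(k * k).repeat(len(keep))
sel = torch.index_select(w2d, dim=0, index=rows)  # shape (len(keep)*k*k, out_c)

# Viewing directly reinterprets the (rows, out_c) memory as (out_c, in, k, k).
no_transpose = sel.view(out_c, len(keep), k, k)
# Transposing first puts the output dimension back in front before reshaping.
with_transpose = sel.t().reshape(out_c, len(keep), k, k)

matches_without_t = torch.equal(no_transpose, expected)
matches_with_t = torch.equal(with_transpose, expected)
print(matches_without_t, matches_with_t)  # prints: False True
```

In this toy case only the transposed variant reproduces the sliced weights, which suggests trying `weight2.t().reshape(ws2[0], pl1, 3, 3)` in the function above as an experiment.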

When testing the model with the default parameters, the top-1 test error changes as follows:

[image: test]

Did you test your model on DenseNet-bottleneck in your experiments?
I was wondering whether there is something wrong with my code. If not, why does the test loss behave like this?
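
A possible way to test the compression code in isolation is to zero out some filters by hand, prune them, and check that the output stays the same. The sketch below does this for two stacked 1x1 convolutions, loosely mirroring the Transition case. The layer sizes, the zeroed filters, and the `keep` indices are assumptions for illustration only, and it uses plain slicing rather than `get_nonzero_index`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the two 1x1 convolutions of the Transition block (sizes are assumptions).
conv1 = nn.Conv2d(8, 6, kernel_size=1, bias=False)
conv2 = nn.Conv2d(6, 6, kernel_size=1, bias=False)

# Zero two output filters of conv1 so that removing them should not change the result.
with torch.no_grad():
    conv1.weight[[1, 4]] = 0.0

x = torch.randn(2, 8, 5, 5)
with torch.no_grad():
    reference = conv2(conv1(x))

keep = torch.tensor([0, 2, 3, 5])  # hypothetical indices of the filters that remain nonzero

# Pruned copies: keep output filters of conv1 and the matching input channels of conv2.
small1 = nn.Conv2d(8, len(keep), kernel_size=1, bias=False)
small2 = nn.Conv2d(len(keep), 6, kernel_size=1, bias=False)
with torch.no_grad():
    small1.weight.copy_(conv1.weight[keep])
    small2.weight.copy_(conv2.weight[:, keep])
    pruned = small2(small1(x))

# The two outputs should agree up to floating-point rounding.
max_diff = (reference - pruned).abs().max().item()
print(max_diff)
```

The same comparison could be run around `compress_module_param` itself. For the BottleNeck branch the outputs would not be expected to match exactly, because the BatchNorm after conv1 can turn a zeroed channel into a nonzero constant.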

Thanks for your time and looking forward to your reply.
