Skip to content

Training loss does not reduce for a moderately sized network #377

@SM1991CODES

Description

@SM1991CODES

Hi @jeshraghian ,

I am trying to do object detection using lidar point clouds as BEVs using SNNs.

My model looks like this:

`
"""
Module contains Pytorch models
"""

import torch
import torch.nn as nn
import snntorch
import snntorch.spikegen as sgen
import torchsummary

# Number of conv output filters per residual-block index (network depth level).
filters_block_index_dict = {1: 16, 2: 32, 3: 64, 4: 128, 5: 256}

class SResBlock(nn.Module):
    """
    Spiking residual block with a stride-2 downsampling output stage.

    Conv/BatchNorm layers produce float currents; Leaky (LIF) layers take a
    current plus a membrane potential and emit binary spikes and the updated
    membrane potential. The block input is concatenated with the second spike
    map (residual path) before the strided output conv.
    """

    def __init__(self, n_channels_in, block_index):
        """
        Args:
            n_channels_in: channel count of the block input.
            block_index: key into ``filters_block_index_dict`` selecting the
                number of output filters for this block.
        """
        super(SResBlock, self).__init__()
        self.block_index = block_index
        self.filters = filters_block_index_dict[block_index]
        self.lif_beta = 0.85  # initial decay; trainable since learn_beta=True

        # self.c_skip = nn.Conv2d(in_channels=n_channels_in, out_channels=self.fil)
        self.c1 = nn.Conv2d(in_channels=n_channels_in, out_channels=self.filters,
                            kernel_size=3, padding=3 // 2, bias=False)
        self.b1 = nn.BatchNorm2d(self.filters)
        self.lif1 = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)

        self.c2 = nn.Conv2d(in_channels=self.filters, out_channels=self.filters,
                            kernel_size=3, padding=3 // 2, bias=False)
        self.b2 = nn.BatchNorm2d(self.filters)
        self.lif2 = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)

        # Strided conv fuses the residual concatenation and halves H, W.
        self.pool = nn.Conv2d(self.filters + n_channels_in, self.filters,
                              kernel_size=5, stride=2, padding=2)
        self.lif_out = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)

    def forward(self, x):
        """
        Pass one timestep of spikes through the block.

        Conv layers take spike inputs and produce float currents; Leaky
        layers take the currents and produce {0, 1} spikes.

        NOTE(review): the membrane potentials are re-initialized on EVERY
        forward call. When the caller invokes this block once per timestep
        (as SBEVDetNet does), no temporal state accumulates across
        timesteps — likely a contributing factor to the training loss not
        decreasing. Consider snntorch's ``init_hidden=True`` neurons or
        threading the membrane states through the caller instead.

        Returns:
            (spk_out, mem_out): output spikes and membrane potential of the
            final LIF stage.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem_out = self.lif_out.init_leaky()

        cur1 = self.b1(self.c1(x))  # float currents
        spk1, mem1 = self.lif1(cur1, mem1)  # {0, 1} spikes and float mem. pot.

        cur2 = self.b2(self.c2(spk1))
        spk2, mem2 = self.lif2(cur2, mem2)
        # Residual path: use torch.cat directly rather than reaching through
        # snntorch's re-exported torch module (fragile, non-public).
        spk2 = torch.cat([spk2, x], dim=1)

        cur_pool = self.pool(spk2)
        spk_out, mem_out = self.lif_out(cur_pool, mem_out)
        return spk_out, mem_out

class SEncoder4(nn.Module):
    """
    Spiking encoder: four stacked SResBlocks, each downsampling by 2
    (overall spatial reduction of 16x; output channels = 128).
    """

    def __init__(self, n_channels_in):
        """
        Set up the architecture.

        Args:
            n_channels_in: channel count of the raw (spiking) input.

        Note: the pasted original had ``def init`` / ``super().init()`` —
        the ``__init__`` dunder underscores were stripped by markdown, so
        the constructor never ran. Restored here.
        """
        super(SEncoder4, self).__init__()

        self.res1 = SResBlock(n_channels_in, 1)
        self.res2 = SResBlock(filters_block_index_dict[1], 2)
        self.res3 = SResBlock(filters_block_index_dict[2], 3)
        self.res4 = SResBlock(filters_block_index_dict[3], 4)

    def forward(self, x):
        """
        Pass one timestep of input through the four residual blocks.

        Returns:
            (x, mem): spikes and membrane potential from the last block.
        """
        x, mem = self.res1(x)
        x, mem = self.res2(x)
        x, mem = self.res3(x)
        x, mem = self.res4(x)
        return x, mem

class SHeads(nn.Module):
    """
    Prediction head: two conv+LIF stages followed by a 1x1 output conv.

    For regression heads (``is_regr=True``) the output LIF uses
    ``reset_mechanism="none"`` so its membrane potential integrates freely
    and can be read out as a continuous value.
    """

    def __init__(self, n_channels_in, n_ch_out, lif_beta=0.80, is_regr=False):
        """
        Set up the architecture.

        Args:
            n_channels_in: channel count of the encoder features.
            n_ch_out: number of output channels for this head.
            lif_beta: initial LIF decay (trainable via learn_beta=True).
            is_regr: True for regression heads (non-resetting output LIF).

        Note: restored ``__init__`` — the pasted original's ``def init``
        had its dunder underscores stripped by markdown.
        """
        super(SHeads, self).__init__()
        self.lif_beta = lif_beta
        self.is_regr = is_regr

        self.c1 = nn.Conv2d(n_channels_in, 32, kernel_size=3, padding=1, bias=False)
        self.b1 = nn.BatchNorm2d(32)
        self.lif1 = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)

        self.c2 = nn.Conv2d(32, 16, kernel_size=1, bias=False)
        self.b2 = nn.BatchNorm2d(16)
        self.lif2 = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)

        self.cout = nn.Conv2d(16, n_ch_out, kernel_size=1, bias=False)

        if self.is_regr is True:
            self.lif_out = snntorch.Leaky(beta=self.lif_beta, learn_beta=True,
                                          threshold=1.0, reset_mechanism="none")
        else:
            self.lif_out = snntorch.Leaky(beta=self.lif_beta, learn_beta=True,
                                          threshold=1.0)

    def forward(self, x):
        """
        Pass one timestep of encoder features through the head.

        NOTE(review): membranes are re-initialized on every call — see the
        note on SResBlock.forward; no state is kept across timesteps.

        Returns:
            (spk_out, mem_out): output spikes and membrane potential.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem_out = self.lif_out.init_leaky()

        cur1 = self.b1(self.c1(x))
        spk1, mem1 = self.lif1(cur1, mem1)

        cur2 = self.b2(self.c2(spk1))
        spk2, mem2 = self.lif2(cur2, mem2)

        cur_out = self.cout(spk2)
        spk_out, mem_out = self.lif_out(cur_out, mem_out)
        return spk_out, mem_out

class SBEVDetNet(nn.Module):
    """
    Spiking BEV detection network: a shared SEncoder4 backbone feeding four
    SHeads — keypoint and rotation heads read out spikes (classification),
    offset and hwl heads read out membrane potentials (regression).
    """

    def __init__(self, n_channels_in, n_timesteps=10):
        """
        Set up the SNN body.

        Args:
            n_channels_in: channel count of the BEV input.
            n_timesteps: number of spike-train timesteps per sample.

        Note: restored ``__init__`` — the pasted original's ``def init``
        had its dunder underscores stripped by markdown.
        """
        super(SBEVDetNet, self).__init__()
        self.n_timesteps = n_timesteps

        self.enc = SEncoder4(n_channels_in)
        self.head_kp = SHeads(128, 2)
        self.head_kp_offset = SHeads(128, 2, is_regr=True)
        self.head_hwl = SHeads(128, 3, is_regr=True)
        self.head_rot = SHeads(128, 37)

    def forward(self, x_spiking):
        """
        Pass a spiking input of shape [timesteps, batch, C, H, W] through
        the network, one timestep at a time.

        Spikes are collected for classification heads; membrane potentials
        for regression heads.

        Returns:
            (kp_spikes, kp_offset_mems, hwl_mems, rot_spikes), each stacked
            along a new leading timestep dimension.
        """
        # mem_rec_kp = []
        spk_rec_kp = []
        spk_rec_rot = []

        mem_rec_kp_off = []
        mem_rec_hwl = []

        # pass each timestep separately
        for t in range(self.n_timesteps):
            x_spk = x_spiking[t]
            x_spk, x_mem = self.enc(x_spk)

            xout_kp_spk, xout_kp_mem = self.head_kp(x_spk)
            spk_rec_kp.append(xout_kp_spk)

            xout_kp_off_spk, xout_kp_off_mem = self.head_kp_offset(x_spk)
            mem_rec_kp_off.append(xout_kp_off_mem)

            xout_hwl_spk, xout_hwl_mem = self.head_hwl(x_spk)
            mem_rec_hwl.append(xout_hwl_mem)

            xout_rot_spk, xout_rot_mem = self.head_rot(x_spk)
            spk_rec_rot.append(xout_rot_spk)

        xout_kp_spk = torch.stack(spk_rec_kp)
        xout_kp_off_mem = torch.stack(mem_rec_kp_off)
        xout_hwl_mem = torch.stack(mem_rec_hwl)
        xout_rot_spk = torch.stack(spk_rec_rot)
        return xout_kp_spk, xout_kp_off_mem, xout_hwl_mem, xout_rot_spk

# Restored the script guard: the pasted original read `if name == "main":`
# because markdown stripped the dunder underscores.
if __name__ == "__main__":

    # res_block = SResBlock(5, 1)
    # torchsummary.summary(res_block, (5, 416, 416))

    # Smoke test: [timesteps=10, batch=4, C=5, H=416, W=416]
    enc = SBEVDetNet(n_channels_in=5, n_timesteps=10)
    t = torch.randn((10, 4, 5, 416, 416))
    enc(t)

`

My training loop looks like this:

`
def train_spiking_bevdetnet(batch_size, num_epochs, n_timesteps):
    """
    Launch training of the spiking BEVDetNet.

    We use spikes for classification — rate-based, i.e. the target neuron
    should spike the most over the timesteps — and membrane potentials for
    regression (torch SmoothL1Loss).

    Initially try rate codes for inputs. Later try using the first NN layers
    to do the spike coding. Also try population coding.
    Try making all targets classification targets if needed:
        - H, W -> bins [0.7, 1.6, @0.1], L -> bins [2.5, 5.1, @0.1]
        - offsets: bins [0, 16, @1]; 16 is the scale factor used for masks

    Args:
        batch_size: samples per training batch.
        num_epochs: number of passes over the dataset.
        n_timesteps: spike-train length per sample.
    """
    print(f"Training spiking BEVDETNet with {n_timesteps} timesteps")
    DBG = False
    snet = storch_models.SBEVDetNet(n_channels_in=5, n_timesteps=n_timesteps).cuda()
    # torchsummary.summary(snet, (5, 416, 416))

    dataclass = kitti_dataset.SpikingKITTIDataset(path_train_npy=settings.path_kitti_train_npy,
                                                  path_val_npy=settings.path_kitti_val_npy,
                                                  dataset_mode="train", debug_flags=False,
                                                  scaled_masks_factor=16)
    dataloader = DataLoader(dataclass, batch_size=batch_size, num_workers=4,
                            drop_last=True, shuffle=True)

    optim = torch.optim.Adam(snet.parameters())
    loss_fn_cls = ce_rate_loss()  # TODO
    # loss_fn_reg = mse_membrane_loss(time_var_targets=False,
    #                                 on_target=0.85, off_target=0.15)
    loss_fn_reg = torch.nn.SmoothL1Loss()

    for epoch in range(num_epochs):
        for index, data in enumerate(dataloader):

            train_x_bev = data[0].permute(0, 3, 1, 2).cuda()
            train_y_s_kp = data[1].cuda()
            train_y_s_kp_off = data[2].permute(0, 3, 1, 2).cuda()
            train_y_s_hwl = data[3].permute(0, 3, 1, 2).cuda()
            train_y_s_rot = data[4].cuda()
            spk_train_x_bev = sgen.rate(train_x_bev, num_steps=n_timesteps).cuda()  # convert to spikes

            spk_yhat_kp, mem_yhat_kp_off, mem_yhat_hwl, spk_yhat_rot = snet(spk_train_x_bev)

            loss_kp = loss_fn_cls(spk_yhat_kp, train_y_s_kp)
            loss_rot = loss_fn_cls(spk_yhat_rot, train_y_s_rot)

            # Loop over all regression outputs and compute loss per timestep.
            # Accumulate as plain floats: the original seeded with
            # torch.tensor(0, device="cuda:0"), an int-dtype tensor on a
            # hardcoded device; float + tensor promotion makes 0.0 equivalent
            # and device-agnostic.
            loss_hwl = 0.0
            loss_offset = 0.0
            for t_step in range(n_timesteps):
                t_loss_hwl = loss_fn_reg(mem_yhat_hwl[t_step], train_y_s_hwl)
                t_loss_kp_off = loss_fn_reg(mem_yhat_kp_off[t_step], train_y_s_kp_off)
                loss_hwl = loss_hwl + t_loss_hwl
                loss_offset = loss_offset + t_loss_kp_off

            total_loss = loss_kp + 0.97 * loss_rot + 0.95 * loss_offset + 0.90 * loss_hwl

            optim.zero_grad()
            total_loss.backward()
            optim.step()
            print(f"Epoch -> {epoch}, loss -> {total_loss.item()}")

`

However, I don't see the training loss decreasing. Please help me out.
I have also attached the two files as .txt:

storch_models.txt

train_models.txt

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions