Hi @jeshraghian,
I am trying to do object detection on LiDAR point clouds, represented as bird's-eye-view (BEV) images, using SNNs.
My model looks like this:
```python
"""
Module contains PyTorch models.
"""
import torch
import torch.nn as nn
import snntorch
import snntorch.spikegen as sgen
import torchsummary

filters_block_index_dict = {1: 16, 2: 32, 3: 64, 4: 128, 5: 256}


class SResBlock(nn.Module):
    """
    Implements a spiking residual block.
    PyTorch layers produce currents.
    Leaky (LIF) layers take membrane potential and current as inputs and produce
    spikes and membrane potential as outputs.
    """
    def __init__(self, n_channels_in, block_index):
        """
        Default constructor.
        """
        super(SResBlock, self).__init__()
        self.block_index = block_index
        self.filters = filters_block_index_dict[block_index]
        self.lif_beta = 0.85
        # self.c_skip = nn.Conv2d(in_channels=n_channels_in, out_channels=self.fil)
        self.c1 = nn.Conv2d(in_channels=n_channels_in, out_channels=self.filters, kernel_size=3, padding=3 // 2, bias=False)
        self.b1 = nn.BatchNorm2d(self.filters)
        self.lif1 = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)
        self.c2 = nn.Conv2d(in_channels=self.filters, out_channels=self.filters, kernel_size=3, padding=3 // 2, bias=False)
        self.b2 = nn.BatchNorm2d(self.filters)
        self.lif2 = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)
        self.pool = nn.Conv2d(self.filters + n_channels_in, self.filters, kernel_size=5, stride=2, padding=2)
        self.lif_out = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)

    def forward(self, x):
        """
        Passes input through the block.
        Conv. layers take spike inputs and produce float currents.
        Leaky layers take the currents and produce spikes {0, 1}.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem_out = self.lif_out.init_leaky()
        cur1 = self.b1(self.c1(x))          # floats
        spk1, mem1 = self.lif1(cur1, mem1)  # {0, 1} spikes and float mem. pot.
        cur2 = self.b2(self.c2(spk1))
        spk2, mem2 = self.lif2(cur2, mem2)
        spk2 = torch.cat([spk2, x], dim=1)  # residual concatenation with the block input
        cur_pool = self.pool(spk2)
        spk_out, mem_out = self.lif_out(cur_pool, mem_out)
        return spk_out, mem_out


class SEncoder4(nn.Module):
    """
    Implements a spiking encoder with four residual blocks.
    """
    def __init__(self, n_channels_in):
        """
        Sets up the architecture.
        """
        super(SEncoder4, self).__init__()
        self.res1 = SResBlock(n_channels_in, 1)
        self.res2 = SResBlock(filters_block_index_dict[1], 2)
        self.res3 = SResBlock(filters_block_index_dict[2], 3)
        self.res4 = SResBlock(filters_block_index_dict[3], 4)

    def forward(self, x):
        """
        Passes input through the network.
        """
        x, mem = self.res1(x)
        x, mem = self.res2(x)
        x, mem = self.res3(x)
        x, mem = self.res4(x)
        return x, mem


class SHeads(nn.Module):
    """
    Implements a final prediction head.
    """
    def __init__(self, n_channels_in, n_ch_out, lif_beta=0.80, is_regr=False):
        """
        Sets up the architecture.
        """
        super(SHeads, self).__init__()
        self.lif_beta = lif_beta
        self.is_regr = is_regr
        self.c1 = nn.Conv2d(n_channels_in, 32, kernel_size=3, padding=1, bias=False)
        self.b1 = nn.BatchNorm2d(32)
        self.lif1 = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)
        self.c2 = nn.Conv2d(32, 16, kernel_size=1, bias=False)
        self.b2 = nn.BatchNorm2d(16)
        self.lif2 = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)
        self.cout = nn.Conv2d(16, n_ch_out, kernel_size=1, bias=False)
        if self.is_regr is True:
            self.lif_out = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0, reset_mechanism="none")
        else:
            self.lif_out = snntorch.Leaky(beta=self.lif_beta, learn_beta=True, threshold=1.0)

    def forward(self, x):
        """
        Passes input through the network.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem_out = self.lif_out.init_leaky()
        cur1 = self.b1(self.c1(x))
        spk1, mem1 = self.lif1(cur1, mem1)
        cur2 = self.b2(self.c2(spk1))
        spk2, mem2 = self.lif2(cur2, mem2)
        cur_out = self.cout(spk2)
        spk_out, mem_out = self.lif_out(cur_out, mem_out)
        return spk_out, mem_out


class SBEVDetNet(nn.Module):
    """
    Implements the spiking BEV detection network (encoder + prediction heads).
    """
    def __init__(self, n_channels_in, n_timesteps=10):
        """
        Method sets up the SNN body.
        """
        super(SBEVDetNet, self).__init__()
        self.n_timesteps = n_timesteps
        self.enc = SEncoder4(n_channels_in)
        self.head_kp = SHeads(128, 2)
        self.head_kp_offset = SHeads(128, 2, is_regr=True)
        self.head_hwl = SHeads(128, 3, is_regr=True)
        self.head_rot = SHeads(128, 37)

    def forward(self, x_spiking):
        """
        Passes spiking input through the network.
        Spikes are used for the classification heads while membrane potentials are used for the regression tasks.
        """
        # mem_rec_kp = []
        spk_rec_kp = []
        spk_rec_rot = []
        mem_rec_kp_off = []
        mem_rec_hwl = []
        # pass each timestep separately
        for t in range(self.n_timesteps):
            x_spk = x_spiking[t]
            x_spk, x_mem = self.enc(x_spk)
            xout_kp_spk, xout_kp_mem = self.head_kp(x_spk)
            spk_rec_kp.append(xout_kp_spk)
            xout_kp_off_spk, xout_kp_off_mem = self.head_kp_offset(x_spk)
            mem_rec_kp_off.append(xout_kp_off_mem)
            xout_hwl_spk, xout_hwl_mem = self.head_hwl(x_spk)
            mem_rec_hwl.append(xout_hwl_mem)
            xout_rot_spk, xout_rot_mem = self.head_rot(x_spk)
            spk_rec_rot.append(xout_rot_spk)
        xout_kp_spk = torch.stack(spk_rec_kp)
        xout_kp_off_mem = torch.stack(mem_rec_kp_off)
        xout_hwl_mem = torch.stack(mem_rec_hwl)
        xout_rot_spk = torch.stack(spk_rec_rot)
        return xout_kp_spk, xout_kp_off_mem, xout_hwl_mem, xout_rot_spk


if __name__ == "__main__":
    # res_block = SResBlock(5, 1)
    # torchsummary.summary(res_block, (5, 416, 416))
    enc = SBEVDetNet(n_channels_in=5, n_timesteps=10)
    t = torch.randn((10, 4, 5, 416, 416))
    enc(t)
```
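For reference, this is roughly how I feed spike-encoded input into the model (the batch size and values below are just an example; I assume the BEV tensor is normalised to [0, 1] so it can act as spiking probabilities):

```python
import torch
import snntorch.spikegen as sgen

# Example only: a batch of 4 BEV tensors with 5 channels at 416x416, values in [0, 1]
bev = torch.rand((4, 5, 416, 416))

# Rate-code the static input over 10 timesteps -> shape [10, 4, 5, 416, 416],
# matching the [T, B, C, H, W] layout that SBEVDetNet.forward expects
spk_in = sgen.rate(bev, num_steps=10)

model = SBEVDetNet(n_channels_in=5, n_timesteps=10)
spk_kp, mem_kp_off, mem_hwl, spk_rot = model(spk_in)

# Each SResBlock downsamples by 2, so 416 -> 26 after four blocks:
# [10, 4, 2, 26, 26], [10, 4, 2, 26, 26], [10, 4, 3, 26, 26], [10, 4, 37, 26, 26]
print(spk_kp.shape, mem_kp_off.shape, mem_hwl.shape, spk_rot.shape)
```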
My training loop looks like this:
```python
def train_spiking_bevdetnet(batch_size, num_epochs, n_timesteps):
    """
    Function launches training.
    We use spikes for classification - rate based => the target neuron should spike the most over the time steps.
    Membrane potentials are used for regression - with torch SmoothL1Loss.
    Initially try rate codes for the inputs. Later try letting the first NN layers do the spike coding.
    Also try population coding.
    Try making all targets classification targets if needed:
    - H, W -> bins [0.7, 1.6, @0.1], L -> bins [2.5, 5.1, @0.1]
    - offsets: bins [0, 16, @1], 16: scale factor used for scaling the masks
    """
    print(f"Training spiking BEVDetNet with {n_timesteps} timesteps")
    DBG = False
    snet = storch_models.SBEVDetNet(n_channels_in=5, n_timesteps=n_timesteps).cuda()
    # torchsummary.summary(snet, (5, 416, 416))
    dataclass = kitti_dataset.SpikingKITTIDataset(path_train_npy=settings.path_kitti_train_npy,
                                                  path_val_npy=settings.path_kitti_val_npy,
                                                  dataset_mode="train", debug_flags=False, scaled_masks_factor=16)
    dataloader = DataLoader(dataclass, batch_size=batch_size, num_workers=4, drop_last=True, shuffle=True)
    optim = torch.optim.Adam(snet.parameters())
    loss_fn_cls = ce_rate_loss()  # from snntorch.functional; TODO
    # loss_fn_reg = mse_membrane_loss(time_var_targets=False,
    #                                 on_target=0.85, off_target=0.15)
    loss_fn_reg = torch.nn.SmoothL1Loss()
    for epoch in range(num_epochs):
        for index, data in enumerate(dataloader):
            train_x_bev = data[0].permute(0, 3, 1, 2).cuda()
            train_y_s_kp = data[1].cuda()
            train_y_s_kp_off = data[2].permute(0, 3, 1, 2).cuda()
            train_y_s_hwl = data[3].permute(0, 3, 1, 2).cuda()
            train_y_s_rot = data[4].cuda()
            spk_train_x_bev = sgen.rate(train_x_bev, num_steps=n_timesteps).cuda()  # convert to spikes
            spk_yhat_kp, mem_yhat_kp_off, mem_yhat_hwl, spk_yhat_rot = snet(spk_train_x_bev)
            loss_kp = loss_fn_cls(spk_yhat_kp, train_y_s_kp)
            loss_rot = loss_fn_cls(spk_yhat_rot, train_y_s_rot)
            # loop over all regression outputs and compute the loss for each time step
            loss_hwl = torch.tensor(0.0, device="cuda:0")
            loss_offset = torch.tensor(0.0, device="cuda:0")
            for t_step in range(n_timesteps):
                t_loss_hwl = loss_fn_reg(mem_yhat_hwl[t_step], train_y_s_hwl)
                t_loss_kp_off = loss_fn_reg(mem_yhat_kp_off[t_step], train_y_s_kp_off)
                loss_hwl = loss_hwl + t_loss_hwl
                loss_offset = loss_offset + t_loss_kp_off
            total_loss = loss_kp + 0.97 * loss_rot + 0.95 * loss_offset + 0.90 * loss_hwl
            optim.zero_grad()
            total_loss.backward()
            optim.step()
        print(f"Epoch -> {epoch}, loss -> {total_loss.item()}")
```
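Regarding the "make all targets classification targets" idea in the docstring above, this is roughly what I mean by binning the regression targets (the helper below is only an illustration with the ranges from the docstring, not something I currently train with):

```python
import torch

def to_bin_index(values, lo, hi, step):
    """Map continuous values to integer bin indices in [0, (hi - lo) / step]."""
    # clamp so out-of-range boxes still land in the first/last bin
    values = values.clamp(min=lo, max=hi)
    return ((values - lo) / step).long()

# H, W in [0.7, 1.6] m at 0.1 m resolution -> 10 bins
# L    in [2.5, 5.1] m at 0.1 m resolution -> 27 bins
# offsets in [0, 16] px at 1 px resolution -> 17 bins
h = torch.tensor([0.72, 1.55])
print(to_bin_index(h, 0.7, 1.6, 0.1))  # tensor([0, 8])
```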
However, I don't see the training loss decreasing. Please help me out.
I have also attached the files as .txt.