Open
Description
In the paper, the ISTA update is described as $r_{l+1} = \max\{(r_l - \mu\nabla_l), 0\} \cdot \operatorname{sgn}(r_l - \mu\nabla_l)$, but in `sgd.py` it is implemented as:
```python
x = p.data.add(-group['lr'],d_p)
x = torch.clamp((torch.abs(x) - ista), min=0.)
p.data = x * torch.sign(x)
```
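A quick worked example with hypothetical values shows the effect of those three lines on a single element: suppose that after the gradient step the element is $x = -0.5$ and `ista` is $0.1$.

```math
x = \max(|-0.5| - 0.1,\ 0) = 0.4, \qquad p = x \cdot \operatorname{sgn}(x) = 0.4 \cdot 1 = 0.4
```

Taking the sign of the pre-threshold value instead would give $0.4 \cdot \operatorname{sgn}(-0.5) = -0.4$, which matches the $\operatorname{sgn}(r_l - \mu\nabla_l)$ factor in the paper's formula.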
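Because `x` is overwritten with the clamped magnitude, it is always non-negative, so `torch.sign(x)` only returns 0 or 1 and the original sign of every weight is lost. I think the correct code should take the sign of the value before thresholding: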
```python
x = p.data.add(-group['lr'],d_p)
y = torch.clamp((torch.abs(x) - ista), min=0.)
p.data = y * torch.sign(x)
```
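To compare the two versions without a PyTorch install, here is a minimal plain-Python sketch that ports both element-wise onto lists of floats. The helper names `sign`, `original_update` and `proposed_update` and the sample values are hypothetical; `max(..., 0.0)` stands in for `torch.clamp(..., min=0.)`, and `sign` mimics `torch.sign` (returning 0 at 0).

```python
def sign(v: float) -> float:
    """Scalar stand-in for torch.sign: -1.0, 0.0 or 1.0."""
    return float((v > 0) - (v < 0))


def original_update(w, grad, lr, ista):
    """Element-wise port of the current sgd.py lines (sign taken AFTER clamping)."""
    out = []
    for wi, gi in zip(w, grad):
        x = wi - lr * gi               # p.data.add(-lr, d_p)
        x = max(abs(x) - ista, 0.0)    # torch.clamp(|x| - ista, min=0.)
        out.append(x * sign(x))        # x >= 0 here, so sign(x) is only 0 or 1
    return out


def proposed_update(w, grad, lr, ista):
    """Element-wise port of the proposed fix (sign taken BEFORE clamping)."""
    out = []
    for wi, gi in zip(w, grad):
        x = wi - lr * gi               # gradient step
        y = max(abs(x) - ista, 0.0)    # shrink the magnitude by ista
        out.append(y * sign(x))        # restore the sign of the pre-threshold value
    return out


w = [-0.5, 0.05, 0.3]
grad = [1.0, 0.0, -1.0]
print(original_update(w, grad, lr=0.1, ista=0.1))
print(proposed_update(w, grad, lr=0.1, ista=0.1))
```

With these inputs the first weight moves to about -0.6 after the gradient step; the original lines turn it into a positive value, while the proposed version keeps it negative. Small weights (the second entry) are zeroed by both.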