-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathwavelet.py
More file actions
40 lines (33 loc) · 1.6 KB
/
wavelet.py
File metadata and controls
40 lines (33 loc) · 1.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# -*- coding: utf-8 -*-
import pywt
import torch
from torch.autograd import Variable
# Filter taps of the 'db1' wavelet, as provided by PyWavelets.
w = pywt.Wavelet('db1')

# The decomposition taps are stored reversed; the reconstruction taps are
# kept in the order PyWavelets returns them.
# NOTE(review): the reversal presumably compensates for conv2d computing a
# cross-correlation rather than a flipped-kernel convolution -- confirm.
dec_hi = torch.Tensor(w.dec_hi[::-1])
dec_lo = torch.Tensor(w.dec_lo[::-1])
rec_hi = torch.Tensor(w.rec_hi)
rec_lo = torch.Tensor(w.rec_lo)


def _separable_kernel(along_cols, along_rows):
    """Return the 2-D kernel k[r, c] = along_rows[r] * along_cols[c]."""
    return along_cols[None, :] * along_rows[:, None]


# Four 2-D analysis kernels stacked along dim 0. The lo/lo kernel is halved
# here and doubled again in ``inv_filters`` so the two scalings cancel.
filters = torch.stack(
    [
        _separable_kernel(dec_lo, dec_lo) / 2.0,
        _separable_kernel(dec_lo, dec_hi),
        _separable_kernel(dec_hi, dec_lo),
        _separable_kernel(dec_hi, dec_hi),
    ],
    dim=0,
)

# Matching synthesis kernels, stacked in the same band order as ``filters``.
inv_filters = torch.stack(
    [
        _separable_kernel(rec_lo, rec_lo) * 2.0,
        _separable_kernel(rec_lo, rec_hi),
        _separable_kernel(rec_hi, rec_lo),
        _separable_kernel(rec_hi, rec_hi),
    ],
    dim=0,
)
def wt(vimg):
    """Apply one level of the 2-D wavelet transform to every channel.

    Each input channel is filtered with the four kernels in the module-level
    ``filters`` using a stride-2 convolution, so channel ``i`` of the input
    produces output channels ``4*i`` .. ``4*i+3`` (same band order as
    ``filters``). The last three bands of every group are remapped with
    ``(x + 1) / 2``; the first band is stored as-is.
    NOTE(review): the remap presumably moves detail bands from [-1, 1] into
    [0, 1] -- confirm against the callers' value range.

    Args:
        vimg: 4-D floating-point tensor laid out as (batch, channels,
            height, width), on any device.

    Returns:
        A new tensor of shape (batch, 4 * channels, height // 2, width // 2)
        on the same device and with the same dtype as ``vimg``.
    """
    channels = vimg.shape[1]
    # Move the constant filters to the input's device/dtype once, instead of
    # forcing CUDA and re-uploading them for every channel. They are fixed
    # coefficients, so they are deliberately not marked as requiring grad.
    weight = filters[:, None].to(device=vimg.device, dtype=vimg.dtype)
    res = vimg.new_zeros(
        (vimg.shape[0], 4 * channels, vimg.shape[2] // 2, vimg.shape[3] // 2)
    )
    for i in range(channels):
        # Slice with i:i+1 to keep the channel axis that conv2d expects.
        bands = torch.nn.functional.conv2d(vimg[:, i:i + 1], weight, stride=2)
        res[:, 4 * i] = bands[:, 0]
        res[:, 4 * i + 1:4 * i + 4] = (bands[:, 1:] + 1) / 2.0
    return res
def iwt(vres):
    """Invert one level of the 2-D wavelet transform.

    Consumes the input four channels at a time. Within each group the last
    three bands are first mapped back with ``2 * x - 1``, then all four bands
    are combined by a stride-2 transposed convolution with the module-level
    ``inv_filters`` into a single output channel. Trailing channels beyond
    the last complete group of four are ignored.

    The input tensor is left unmodified.

    Args:
        vres: 4-D floating-point tensor laid out as (batch, 4 * channels,
            height, width), on any device.

    Returns:
        A new tensor of shape (batch, channels, 2 * height, 2 * width) on the
        same device and with the same dtype as ``vres``.
    """
    out_channels = vres.shape[1] // 4
    # Move the constant filters to the input's device/dtype once, instead of
    # forcing CUDA and re-uploading them for every channel. They are fixed
    # coefficients, so they are deliberately not marked as requiring grad.
    weight = inv_filters[:, None].to(device=vres.device, dtype=vres.dtype)
    res = vres.new_zeros(
        (vres.shape[0], out_channels, vres.shape[2] * 2, vres.shape[3] * 2)
    )
    for i in range(out_channels):
        group = vres[:, 4 * i:4 * i + 4]
        # Build the de-normalised bands out of place so the caller's tensor
        # is not overwritten (repeated calls stay idempotent, autograd-safe).
        bands = torch.cat([group[:, :1], 2 * group[:, 1:] - 1], dim=1)
        res[:, i:i + 1] = torch.nn.functional.conv_transpose2d(
            bands, weight, stride=2
        )
    return res