Commit df52e51

chyomin06 authored and fracape committed
[feat] complete support of integer transposed conv2d
1 parent 67c9f41 commit df52e51

File tree

3 files changed: +97, -33 lines

compressai_vision/model_wrappers/detectron2.py

Lines changed: 52 additions & 4 deletions

@@ -104,6 +104,50 @@ def forward(self, x: torch.Tensor):
         return x
 
 
+class ConvTranspose2d(IntTransposedConv2d):
+    def __init__(self, *args, **kwargs) -> None:
+        """
+        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
+
+        Args:
+            norm (nn.Module, optional): a normalization layer
+            activation (callable(Tensor) -> Tensor): a callable activation function
+
+        It assumes that the norm layer is used before the activation.
+        """
+
+        norm = kwargs.pop("norm", None)
+        activation = kwargs.pop("activation", None)
+        super().__init__(*args, **kwargs)
+
+        self.norm = norm
+        self.activation = activation
+
+    def set_attributes(self, module):
+
+        if hasattr(module, "norm"):
+            self.norm = module.norm
+
+        if hasattr(module, "activation"):
+            self.activation = module.activation
+
+        if hasattr(module, "bias"):
+            self.bias = module.bias
+
+    def forward(self, x: torch.Tensor):
+        if not self.initified_weight_mode:
+            x = self.transposedconv2d(x)
+        else:
+            x = self.integer_transposeconv2d(x)
+
+        if self.norm is not None:
+            x = self.norm(x)
+        if self.activation is not None:
+            x = self.activation(x)
+
+        return x
+
+
 class Split_Points(Enum):
     def __str__(self):
         return str(self.value)
@@ -188,15 +232,15 @@ def size_divisibility(self):
 
     def replace_conv2d_modules(self, module):
        for child_name, child_module in module.named_children():
-            if type(child_module).__name__ in ["Conv2d", "TransposedConv2d"]:
+            if type(child_module).__name__ in ["Conv2d", "ConvTranspose2d"]:
                 if type(child_module).__name__ == "Conv2d":
                     int_module = Conv2d(**child_module.__dict__)
                     int_module.set_attributes(child_module)
                 else:
-                    int_module = IntTransposedConv2d(**child_module.__dict__)
+                    int_module = ConvTranspose2d(**child_module.__dict__)
                     int_module.set_attributes(child_module)
 
-                # Since regular list is used instead of ModuleList
+                # Since a regular list is used instead of ModuleList in Backbone
                 if "fpn_lateral" in child_name or "fpn_output" in child_name:
                     idx = re.findall(r"\d", child_name)
                     assert len(idx) == 1
@@ -211,12 +255,16 @@ def replace_conv2d_modules(self, module):
 
                 setattr(module, child_name, int_module)
             else:
+                # WATCH OUT: RECURSIVE FUNCTION CALLS
+                # The function could be rewritten to iterate explicitly over each module
+                # that contains Conv2d and transposed Conv2d layers, e.g.
+                # type(module).__name__ in ["FPN", "BasicStem", "BottleneckBlock", "StandardRPNHead", "MaskRCNNConvUpsampledHead"]
                 self.replace_conv2d_modules(child_module)
 
     @staticmethod
     def quantize_weights(model):
         for _, m in model.named_modules():
-            if type(m).__name__ == "Conv2d":
+            if type(m).__name__ in ["Conv2d", "ConvTranspose2d"]:
                 # print(f"Module name: {name} and type {type(m).__name__}")
                 m.quantize_weights()

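For orientation, the new ConvTranspose2d wrapper above is picked up by the same two-pass flow as the existing Conv2d path: replace_conv2d_modules() swaps the detectron2 layers for their integer-capable counterparts, and the static quantize_weights() then intifies their weights. The sketch below mirrors that pattern with plain torch.nn modules; the helper names and the toy model are hypothetical stand-ins, not part of this commit.

import torch.nn as nn


class IntReadyConvTranspose2d(nn.ConvTranspose2d):
    """Hypothetical stand-in for the ConvTranspose2d/IntTransposedConv2d wrapper."""

    def quantize_weights(self):
        # The real intification lives in intconv2d.py; here we only flag the mode.
        self.int_mode = True


def replace_modules(module: nn.Module) -> None:
    # Recursively swap plain ConvTranspose2d children, mirroring replace_conv2d_modules().
    for name, child in module.named_children():
        if type(child).__name__ == "ConvTranspose2d":
            new = IntReadyConvTranspose2d(
                child.in_channels,
                child.out_channels,
                kernel_size=child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                output_padding=child.output_padding,
                bias=child.bias is not None,
            )
            new.load_state_dict(child.state_dict())
            setattr(module, name, new)
        else:
            replace_modules(child)  # watch out: recursive call, as noted in the diff


def quantize_weights(model: nn.Module) -> None:
    # Second pass, matching the @staticmethod in the diff above.
    for m in model.modules():
        if type(m).__name__ == "IntReadyConvTranspose2d":
            m.quantize_weights()


toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ConvTranspose2d(8, 3, 2, stride=2))
replace_modules(toy)
quantize_weights(toy)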
compressai_vision/model_wrappers/intconv2d.py

Lines changed: 44 additions & 29 deletions

@@ -39,9 +39,15 @@ class IntConv2d(torch.nn.Conv2d):
     def __init__(self, *args, **kwargs) -> None:
         _nkwargs = copy.deepcopy(kwargs)
 
-        del _nkwargs["training"]
-        del _nkwargs["transposed"]
-        del _nkwargs["output_padding"]
+        if _nkwargs.get("training") is not None:
+            del _nkwargs["training"]
+
+        if _nkwargs.get("transposed") is not None:
+            del _nkwargs["transposed"]
+
+        if _nkwargs.get("output_padding") is not None:
+            del _nkwargs["output_padding"]
+
         for name in kwargs.keys():
             if name.startswith("_"):
                 del _nkwargs[name]
@@ -67,7 +73,7 @@ def quantize_weights(self):
         )
         _precision = 2 ** (23 + 1)
 
-        ###### REFERENCE FROM VCMRS ######
+        ###### ADOPT VCMRS IMPLEMENTATION ######
         # sf const
         sf_const = 48
 
@@ -94,14 +100,14 @@ def quantize_weights(self):
             self.bias.requires_grad = False  # Just make sure
             self.bias.zero_()
 
-        ###### END OF REFERENCE FROM VCMRS ######
+        ###### END OF THE REFERENCE IMPLEMENTATION OF THE INT CONVS IN VCMRS ######
 
     def integer_conv2d(self, x: torch.Tensor):
         _dtype = x.dtype
         _cudnn_enabled = torch.backends.cudnn.enabled
         torch.backends.cudnn.enabled = False
 
-        ###### REFERENCE FROM VCMRS ######
+        ###### ADOPT VCMRS IMPLEMENTATION ######
         # Calculate factor
         fx = 1
 
@@ -124,15 +130,15 @@ def integer_conv2d(self, x: torch.Tensor):
         )
 
         # x should be all integers
-        out_x = out_x / (fx * self.fw.view(-1, 1, 1)).float()
+        out_x = out_x / (fx * self.fw.to(out_x.device).view(-1, 1, 1)).float()
 
         # apply bias in float format
         out_x = (
-            (out_x.permute(0, 2, 3, 1) + self.float_bias)
+            (out_x.permute(0, 2, 3, 1) + self.float_bias.to(out_x.device))
             .permute(0, 3, 1, 2)
             .contiguous()
         )
-        ###### REFERENCE FROM VCMRS ######
+        ###### END OF THE REFERENCE IMPLEMENTATION OF THE INT CONVS IN VCMRS ######
         torch.backends.cudnn.enabled = _cudnn_enabled
 
         return out_x.to(_dtype)
@@ -150,12 +156,15 @@ def conv2d(self, x: torch.Tensor):
 
 
 class IntTransposedConv2d(torch.nn.ConvTranspose2d):
-    def __init__(self, *args, **kwarg) -> None:
+    def __init__(self, *args, **kwargs) -> None:
         _nkwargs = copy.deepcopy(kwargs)
 
-        del _nkwargs["training"]
-        del _nkwargs["transposed"]
-        del _nkwargs["output_padding"]
+        if _nkwargs.get("training") is not None:
+            del _nkwargs["training"]
+
+        if _nkwargs.get("transposed") is not None:
+            del _nkwargs["transposed"]
+
         for name in kwargs.keys():
             if name.startswith("_"):
                 del _nkwargs[name]
@@ -164,7 +173,7 @@ def __init__(self, *args, **kwarg) -> None:
         self.initified_weight_mode = False
 
     # prepare quantized weights
-    def quantize(self):
+    def quantize_weights(self):
         self.initified_weight_mode = True
 
         if self.bias is None:
@@ -182,22 +191,21 @@ def quantize(self):
         )
         _precision = 2 ** (23 + 1)
 
-        ###### REFERENCE FROM VCMRS ######
-        #sf const
+        ###### ADOPT VCMRS IMPLEMENTATION ######
+        # sf const
         sf_const = 48
 
-        #N = np.prod(self.weight.shape[1:])
-        N = np.prod(self.weight.shape) / self.weight.shape[1] # (in, out, kH, kW)
+        N = np.prod(self.weight.shape) / self.weight.shape[1]  # (in, out, kH, kW)
         self.N = N
         self.factor = np.sqrt(_precision)
-        #self.sf = 1/6 #precision bits allocation factor
+        # self.sf = 1/6  # precision bits allocation factor
         self.sf = np.sqrt(sf_const / N)
 
         # perform the calculation on CPU to stabilize it
         self.w_sum = self.weight.cpu().abs().sum(axis=[0, 2, 3]).to(self.weight.device)
-        self.w_sum[self.w_sum == 0] = 1 # prevent divide by 0
+        self.w_sum[self.w_sum == 0] = 1  # prevent divide by 0
 
-        self.fw = (self.factor / self.sf - np.sqrt(N / 12) * 5) / self.w_sum
+        self.fw = (self.factor / self.sf - np.sqrt(N / 12) * 5) / self.w_sum
 
         # intify weights
         self.weight.requires_grad = False  # Just make sure
@@ -210,14 +218,14 @@ def quantize(self):
             self.bias.requires_grad = False  # Just make sure
             self.bias.zero_()
 
-        ###### END OF REFERENCE FROM VCMRS ######
+        ###### END OF THE REFERENCE IMPLEMENTATION OF THE INT CONVS IN VCMRS ######
 
     def integer_transposeconv2d(self, x: torch.Tensor):
         _dtype = x.dtype
         _cudnn_enabled = torch.backends.cudnn.enabled
         torch.backends.cudnn.enabled = False
 
-        ###### REFERENCE FROM VCMRS ######
+        ###### ADOPT VCMRS IMPLEMENTATION ######
         # Calculate factor
         fx = 1
 
@@ -227,17 +235,24 @@ def integer_transposeconv2d(self, x: torch.Tensor):
             fx = (self.factor * self.sf - 0.5) / x_max
 
         # intify x
-        x = torch.round(fx * x)
-        x = super().forward(x)
+        out_x = torch.round(fx * x)
+        out_x = super().forward(out_x)
 
         # x should be all integers
-        x /= fx * self.fw.view(-1, 1, 1)
-        x = x.float()
+        out_x = out_x / (fx * self.fw.to(out_x.device).view(-1, 1, 1))
+        out_x = out_x.float()
 
         # apply bias in float format
-        x = (x.permute(0, 2, 3, 1) + self.float_bias).permute(0, 3, 1, 2).contiguous()
+        out_x = (
+            (out_x.permute(0, 2, 3, 1) + self.float_bias.to(out_x.device))
+            .permute(0, 3, 1, 2)
+            .contiguous()
+        )
 
-        ###### REFERENCE FROM VCMRS ######
+        ###### END OF THE REFERENCE IMPLEMENTATION OF THE INT CONVS IN VCMRS ######
         torch.backends.cudnn.enabled = _cudnn_enabled
 
         return out_x.to(_dtype)
+
+    def transposedconv2d(self, x: torch.Tensor):
+        return super().forward(x)

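To make the scaling used by quantize_weights() and integer_transposeconv2d() above easier to follow, here is a small numeric sketch of the factor arithmetic. The layer shape and the w_sum and x_max values are illustrative assumptions only; the formulas themselves restate the ones in the diff.

import numpy as np

# Hypothetical transposed-conv weight of shape (in=8, out=4, kH=3, kW=3).
in_ch, out_ch, kH, kW = 8, 4, 3, 3

_precision = 2 ** (23 + 1)   # largest contiguous integer range representable in float32
sf_const = 48

N = (in_ch * out_ch * kH * kW) / out_ch   # np.prod(shape) / shape[1] for (in, out, kH, kW)
factor = np.sqrt(_precision)              # 4096.0
sf = np.sqrt(sf_const / N)                # precision-bits allocation between inputs and weights

w_sum = 2.0                               # per-output-channel sum of |weight| (illustrative)
fw = (factor / sf - np.sqrt(N / 12) * 5) / w_sum   # weight scale, as in quantize_weights()

x_max = 3.5                               # max |x| observed at the input (illustrative)
fx = (factor * sf - 0.5) / x_max          # input scale, as in integer_transposeconv2d()

# The integer path rounds fx * x, convolves it with the intified weights, rescales the
# result by 1 / (fx * fw) per output channel, and finally adds the original float bias.
print(f"N={N:.0f}, factor={factor:.0f}, sf={sf:.3f}, fw={fw:.1f}, fx={fx:.1f}")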
compressai_vision/model_wrappers/jde.py

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@
 from pathlib import Path
 from typing import Dict, List
 
+import jde
 import torch
 from jde.models import Darknet
 from jde.tracker import matching
