
Commit 67c9f41

chrisr-12 authored and fracape committed

[feat] Add IntTransposedConv

1 parent cf5eb21

File tree

2 files changed: +105, -8 lines


compressai_vision/model_wrappers/detectron2.py

Lines changed: 11 additions & 8 deletions
@@ -1,4 +1,3 @@
-# Copyright (c) 2022-2024, InterDigital Communications, Inc
 # All rights reserved.
 
 # Redistribution and use in source and binary forms, with or without
@@ -46,7 +45,7 @@
 from compressai_vision.registry import register_vision_model
 
 from .base_wrapper import BaseWrapper
-from .intconv2d import IntConv2d
+from .intconv2d import IntConv2d, IntTransposedConv2d
 
 __all__ = [
     "faster_rcnn_X_101_32x8d_FPN_3x",
@@ -189,9 +188,13 @@ def size_divisibility(self):
 
     def replace_conv2d_modules(self, module):
         for child_name, child_module in module.named_children():
-            if type(child_module).__name__ == "Conv2d":
-                int_conv2d = Conv2d(**child_module.__dict__)
-                int_conv2d.set_attributes(child_module)
+            if type(child_module).__name__ in ["Conv2d", "TransposedConv2d"]:
+                if type(child_module).__name__ == "Conv2d":
+                    int_module = Conv2d(**child_module.__dict__)
+                    int_module.set_attributes(child_module)
+                else:
+                    int_module = IntTransposedConv2d(**child_module.__dict__)
+                    int_module.set_attributes(child_module)
 
                 # Since regular list is used instead of ModuleList
                 if "fpn_lateral" in child_name or "fpn_output" in child_name:
@@ -201,12 +204,12 @@ def replace_conv2d_modules(self, module):
                     assert idx in [2, 3, 4, 5]
 
                     if "fpn_lateral" in child_name:
-                        module.lateral_convs[3 - (idx - 2)] = int_conv2d
+                        module.lateral_convs[3 - (idx - 2)] = int_module
                     else:
                         assert "fpn_output" in child_name
-                        module.output_convs[3 - (idx - 2)] = int_conv2d
+                        module.output_convs[3 - (idx - 2)] = int_module
 
-                setattr(module, child_name, int_conv2d)
+                setattr(module, child_name, int_module)
             else:
                 self.replace_conv2d_modules(child_module)
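Note on the replace_conv2d_modules change: detectron2's FPN keeps extra references to the fpn_lateral*/fpn_output* convolutions in plain Python lists (lateral_convs, output_convs) rather than an nn.ModuleList, which is why the diff updates the list slots in addition to calling setattr. A minimal toy sketch of that pitfall (ToyFPN and its shapes are hypothetical, not part of the commit):

import torch.nn as nn

class ToyFPN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fpn_lateral2 = nn.Conv2d(4, 4, 1)
        # plain Python list, as in detectron2's FPN -- not an nn.ModuleList
        self.lateral_convs = [self.fpn_lateral2]

toy = ToyFPN()
new_conv = nn.Conv2d(4, 4, 1)
setattr(toy, "fpn_lateral2", new_conv)
# setattr alone leaves a stale reference behind in the list:
assert toy.lateral_convs[0] is not toy.fpn_lateral2
toy.lateral_convs[0] = new_conv  # hence the explicit slot update in the diff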

compressai_vision/model_wrappers/intconv2d.py

Lines changed: 94 additions & 0 deletions
@@ -147,3 +147,97 @@ def conv2d(self, x: torch.Tensor):
             self.dilation,
             self.groups,
         )
+
+
+class IntTransposedConv2d(torch.nn.ConvTranspose2d):
+    def __init__(self, *args, **kwargs) -> None:
+        _nkwargs = copy.deepcopy(kwargs)
+
+        del _nkwargs["training"]
+        del _nkwargs["transposed"]
+        del _nkwargs["output_padding"]
+        for name in kwargs.keys():
+            if name.startswith("_"):
+                del _nkwargs[name]
+
+        super().__init__(*args, **_nkwargs)
+        self.intified_weight_mode = False
+
+    # prepare quantized weights
+    def quantize(self):
+        self.intified_weight_mode = True
+
+        if self.bias is None:
+            self.float_bias = torch.zeros(self.out_channels, device=self.weight.device)
+        else:
+            self.float_bias = self.bias.detach().clone()
+
+        if self.weight.dtype == torch.float32:
+            _precision = 2 ** (23 + 1)
+        elif self.weight.dtype == torch.float64:
+            _precision = 2 ** (52 + 1)
+        else:
+            logging.warning(
+                f"Unsupported dtype {self.weight.dtype}. Behaviour may lead to unexpected results."
+            )
+            _precision = 2 ** (23 + 1)
+
+        ###### REFERENCE FROM VCMRS ######
+        # sf const
+        sf_const = 48
+
+        # N = np.prod(self.weight.shape[1:])
+        N = np.prod(self.weight.shape) / self.weight.shape[1]  # (in, out, kH, kW)
+        self.N = N
+        self.factor = np.sqrt(_precision)
+        # self.sf = 1/6  # precision bits allocation factor
+        self.sf = np.sqrt(sf_const / N)
+
+        # perform the calculation on CPU to stabilize the computation
+        self.w_sum = self.weight.cpu().abs().sum(axis=[0, 2, 3]).to(self.weight.device)
+        self.w_sum[self.w_sum == 0] = 1  # prevent divide by 0
+
+        self.fw = (self.factor / self.sf - np.sqrt(N / 12) * 5) / self.w_sum
+
+        # intify weights
+        self.weight.requires_grad = False  # Just make sure
+        self.weight.copy_(
+            torch.round(self.weight.detach().clone() * self.fw.view(1, -1, 1, 1))
+        )
+
+        # set bias to 0
+        if self.bias is not None:
+            self.bias.requires_grad = False  # Just make sure
+            self.bias.zero_()
+
+        ###### END OF REFERENCE FROM VCMRS ######
+
+    def integer_transposeconv2d(self, x: torch.Tensor):
+        _dtype = x.dtype
+        _cudnn_enabled = torch.backends.cudnn.enabled
+        torch.backends.cudnn.enabled = False
+
+        ###### REFERENCE FROM VCMRS ######
+        # Calculate factor
+        fx = 1
+
+        x_abs = x.abs()
+        x_max = x_abs.max()
+        if x_max > 0:
+            fx = (self.factor * self.sf - 0.5) / x_max
+
+        # intify x
+        x = torch.round(fx * x)
+        x = super().forward(x)
+
+        # x should be all integers
+        x /= fx * self.fw.view(-1, 1, 1)
+        x = x.float()
+
+        # apply bias in float format
+        x = (x.permute(0, 2, 3, 1) + self.float_bias).permute(0, 3, 1, 2).contiguous()
+
+        ###### END OF REFERENCE FROM VCMRS ######
+        torch.backends.cudnn.enabled = _cudnn_enabled
+
+        return x.to(_dtype)
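On the constants above: _precision = 2 ** (23 + 1) is the float32 significand capacity (24 bits including the implicit bit), so every integer below 2**24 is exactly representable and the intified convolution is exact as long as the accumulated products stay under that bound; 2 ** (52 + 1) plays the same role for float64. A hedged round-trip sketch of the idea, not the commit's code: the scalar fw/fx below are illustrative stand-ins, whereas the commit derives a per-output-channel fw from sf_const and the weight sums, and fx from the input maximum.

import torch

torch.manual_seed(0)
conv = torch.nn.ConvTranspose2d(2, 3, kernel_size=3, bias=False)
x = torch.randn(1, 2, 8, 8)

fw, fx = 512.0, 1024.0  # illustrative scales, kept small enough that
                        # products stay well below 2**24
with torch.no_grad():
    ref = conv(x)
    q = torch.nn.ConvTranspose2d(2, 3, kernel_size=3, bias=False)
    q.weight.copy_(torch.round(conv.weight * fw))  # "intify" the weights
    out = q(torch.round(fx * x)) / (fx * fw)       # integer conv, then rescale

print((ref - out).abs().max())  # small, bounded by the rounding error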
