
Commit cf5eb21

chyomin06 authored and fracape committed
[feat] support intconv2d for Faster RCNN & Mask RCNN
1 parent 19b7ca8 commit cf5eb21

File tree (5 files changed: +161 −66)

cfgs/vision_model/default.yaml
compressai_vision/model_wrappers/detectron2.py
compressai_vision/model_wrappers/intconv2d.py (renamed from intconv_wrapper.py)
compressai_vision/model_wrappers/jde.py
compressai_vision/model_wrappers/jde_lowlevel.py

cfgs/vision_model/default.yaml

Lines changed: 4 additions & 0 deletions
@@ -6,24 +6,28 @@ faster_rcnn_R_50_FPN_3x:
   model_path_prefix: ${..model_root_path}
   cfg: "models/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
   weights: "weights/detectron2/COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl"
+  integer_conv_weight: False
   splits : "r2" #, "c2" or "fpn"

 faster_rcnn_X_101_32x8d_FPN_3x:
   model_path_prefix: ${..model_root_path}
   cfg: "models/detectron2/configs/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"
   weights: "weights/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
+  integer_conv_weight: False
   splits : "fpn" #, "c2" or "r2"

 mask_rcnn_R_50_FPN_3x:
   model_path_prefix: ${..model_root_path}
   cfg: "models/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
   weights: "weights/detectron2/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl"
+  integer_conv_weight: False
   splits : "r2" #, "c2" or "fpn"

 mask_rcnn_X_101_32x8d_FPN_3x:
   model_path_prefix: ${..model_root_path}
   cfg: "models/detectron2/configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"
   weights: "weights/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl"
+  integer_conv_weight: False
   splits : "fpn" #, "c2" or "r2"

 panoptic_rcnn_R_101_FPN_3x:
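
Each model entry gains an `integer_conv_weight` flag, defaulting to `False` so existing configs keep the regular float convolution path. A minimal sketch of how a wrapper can read the flag from its merged config kwargs; the helper name below is illustrative, not part of the commit:

    def wants_integer_conv(model_cfg: dict) -> bool:
        # The YAML default is False, so configs without the key
        # fall back to the float convolution path.
        return bool(model_cfg.get("integer_conv_weight", False))

    assert wants_integer_conv({"integer_conv_weight": True})
    assert not wants_integer_conv({})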

compressai_vision/model_wrappers/detectron2.py

Lines changed: 93 additions & 2 deletions
@@ -27,6 +27,7 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+import re
 from enum import Enum
 from pathlib import Path
 from typing import Dict, List
@@ -45,6 +46,7 @@
 from compressai_vision.registry import register_vision_model

 from .base_wrapper import BaseWrapper
+from .intconv2d import IntConv2d

 __all__ = [
     "faster_rcnn_X_101_32x8d_FPN_3x",
@@ -58,6 +60,51 @@
 root_path = thisdir.joinpath("../..")


+# reference Conv2d from detectron2/detectron2/layers/wrappers.py
+class Conv2d(IntConv2d):
+    def __init__(self, *args, **kwargs) -> None:
+        """
+        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
+
+        Args:
+            norm (nn.Module, optional): a normalization layer
+            activation (callable(Tensor) -> Tensor): a callable activation function
+
+        It assumes that the norm layer is used before the activation.
+        """
+        norm = kwargs.pop("norm", None)
+        activation = kwargs.pop("activation", None)
+        super().__init__(*args, **kwargs)
+
+        self.norm = norm
+        self.activation = activation
+
+    def set_attributes(self, module):
+        if hasattr(module, "norm"):
+            self.norm = module.norm
+
+        if hasattr(module, "activation"):
+            self.activation = module.activation
+
+        if hasattr(module, "bias"):
+            self.bias = module.bias
+
+    def forward(self, x: torch.Tensor):
+        if not self.initified_weight_mode:
+            x = self.conv2d(x)
+        else:
+            x = self.integer_conv2d(x)
+
+        if self.norm is not None:
+            x = self.norm(x)
+        if self.activation is not None:
+            x = self.activation(x)
+
+        return x
+
+
 class Split_Points(Enum):
     def __str__(self):
         return str(self.value)
@@ -79,16 +126,28 @@ def __init__(self, device: str, **kwargs):
             else kwargs["model_path_prefix"]
         )
         self._cfg.merge_from_file(f"{_path_prefix}/{kwargs['cfg']}")
+        _integer_conv_weight = bool(kwargs["integer_conv_weight"])

-        self.model = build_model(self._cfg).to(device).eval()
+        self.model = build_model(self._cfg)
+        self.replace_conv2d_modules(self.model)
+        self.model = self.model.to(device).eval()
+
+        DetectionCheckpointer(self.model).load(f"{_path_prefix}/{kwargs['weights']}")
+
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        # must be called after loading weights to a model
+        if _integer_conv_weight:
+            self.model = self.quantize_weights(self.model)

         self.backbone = self.model.backbone
         self.top_block = self.model.backbone.top_block
         self.proposal_generator = self.model.proposal_generator
         self.roi_heads = self.model.roi_heads
         self.postprocess = self.model._postprocess
-        DetectionCheckpointer(self.model).load(f"{_path_prefix}/{kwargs['weights']}")

+        # to be used for printing info logs
         self.model_info = {"cfg": kwargs["cfg"], "weights": kwargs["weights"]}

         self.supported_split_points = Split_Points
@@ -128,6 +187,38 @@ def SPLIT_R2(self):
     def size_divisibility(self):
         return self.backbone.size_divisibility

+    def replace_conv2d_modules(self, module):
+        for child_name, child_module in module.named_children():
+            if type(child_module).__name__ == "Conv2d":
+                int_conv2d = Conv2d(**child_module.__dict__)
+                int_conv2d.set_attributes(child_module)
+
+                # Since a regular list is used instead of ModuleList
+                if "fpn_lateral" in child_name or "fpn_output" in child_name:
+                    idx = re.findall(r"\d", child_name)
+                    assert len(idx) == 1
+                    idx = int(idx[0])
+                    assert idx in [2, 3, 4, 5]
+
+                    if "fpn_lateral" in child_name:
+                        module.lateral_convs[3 - (idx - 2)] = int_conv2d
+                    else:
+                        assert "fpn_output" in child_name
+                        module.output_convs[3 - (idx - 2)] = int_conv2d
+
+                setattr(module, child_name, int_conv2d)
+            else:
+                self.replace_conv2d_modules(child_module)
+
+    @staticmethod
+    def quantize_weights(model):
+        for _, m in model.named_modules():
+            if type(m).__name__ == "Conv2d":
+                m.quantize_weights()
+
+        return model
+
     def input_resize(self, images: List):
         return ImageList.from_tensors(images, self.size_divisibility)
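
The FPN special case in `replace_conv2d_modules` exists because detectron2's FPN holds its lateral/output convolutions both as named attributes (`fpn_lateral2` through `fpn_output5`) and in plain Python lists (`lateral_convs`, `output_convs`) stored top-down, which is why the index is reversed as `3 - (idx - 2)`; both references must be repointed at the replacement module. A stripped-down sketch of the generic recursion without that FPN bookkeeping, with an illustrative `factory` callable:

    import torch.nn as nn

    def replace_conv2d(module: nn.Module, factory) -> None:
        # Depth-first walk over children; swap each plain Conv2d in its parent.
        for name, child in module.named_children():
            if type(child).__name__ == "Conv2d":
                setattr(module, name, factory(child))
            else:
                replace_conv2d(child, factory)

For example, `replace_conv2d(model, lambda c: Conv2d(**c.__dict__))` mirrors the construction pattern used in the method above.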

compressai_vision/model_wrappers/intconv_wrapper.py renamed to compressai_vision/model_wrappers/intconv2d.py

Lines changed: 48 additions & 60 deletions
@@ -27,57 +27,27 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+import copy
 import logging

 import numpy as np
 import torch
-from torch import nn
-
-
-class IntConv2dWrapper(nn.Conv2d):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size,
-        stride=1,
-        padding=0,
-        dilation=1,
-        groups: int = 1,
-        bias: bool = True,
-        padding_mode: str = "zeros",
-        device=None,
-        dtype=None,
-    ) -> None:
-        super().__init__(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            padding,
-            dilation,
-            groups,
-            bias,
-            padding_mode,
-            device,
-            dtype,
-        )
-        self.initified_weight_mode = False
-
-    """
-    def _set_mode(mode):
-        global _precision, _high_precision, _mode
-
-        if mode == 'none':
-            _precision = 0
-        elif mode == 'float32':
-            _precision = 2**(23+1)
-        elif mode == 'float64':
-            _precision = 2**(52+1)
-
-        _mode = mode
-        torch.backends.cudnn.enabled = mode=='none'
-    """
+from torch.nn import functional as F
+
+
+class IntConv2d(torch.nn.Conv2d):
+    def __init__(self, *args, **kwargs) -> None:
+        _nkwargs = copy.deepcopy(kwargs)
+
+        # pop() with a default tolerates kwargs from a plain constructor
+        # call as well as kwargs copied from another module's __dict__
+        _nkwargs.pop("training", None)
+        _nkwargs.pop("transposed", None)
+        _nkwargs.pop("output_padding", None)
+        for name in kwargs.keys():
+            if name.startswith("_"):
+                del _nkwargs[name]
+
+        super().__init__(*args, **_nkwargs)
+        self.initified_weight_mode = False

     def quantize_weights(self):
         self.initified_weight_mode = True
@@ -97,7 +67,7 @@ def quantize_weights(self):
         )
         _precision = 2 ** (23 + 1)

-        ###### REFERENCE FROM VCMRMS ######
+        ###### REFERENCE FROM VCMRS ######
         # sf const
         sf_const = 48

@@ -124,18 +94,14 @@ def quantize_weights(self):
         self.bias.requires_grad = False  # Just make sure
         self.bias.zero_()

-        ###### END OF REFERENCE FROM VCMRMS ######
-
-    def forward(self, x: torch.Tensor):
-        if not self.initified_weight_mode:
-            return super().forward(x)
+        ###### END OF REFERENCE FROM VCMRS ######

+    def integer_conv2d(self, x: torch.Tensor):
         _dtype = x.dtype
         _cudnn_enabled = torch.backends.cudnn.enabled
         torch.backends.cudnn.enabled = False

-        ###### REFERENCE FROM VCMRMS ######
-
+        ###### REFERENCE FROM VCMRS ######
         # Calculate factor
         fx = 1

@@ -145,17 +111,39 @@ def forward(self, x: torch.Tensor):
             fx = (self.factor * self.sf - 0.5) / x_max

         # intify x
-        x = torch.round(fx * x)
-        x = super().forward(x)
+        out_x = torch.round(fx * x)
+
+        out_x = F.conv2d(
+            out_x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )

         # x should be all integers
-        x /= fx * self.fw.view(-1, 1, 1)
-        x = x.float()
+        out_x = out_x / (fx * self.fw.view(-1, 1, 1)).float()

         # apply bias in float format
-        x = (x.permute(0, 2, 3, 1) + self.float_bias).permute(0, 3, 1, 2).contiguous()
-        ###### REFERENCE FROM VCMRMS ######
-
+        out_x = (
+            (out_x.permute(0, 2, 3, 1) + self.float_bias)
+            .permute(0, 3, 1, 2)
+            .contiguous()
+        )
+        ###### END OF REFERENCE FROM VCMRS ######
         torch.backends.cudnn.enabled = _cudnn_enabled

-        return x.to(_dtype)
+        return out_x.to(_dtype)
+
+    def conv2d(self, x: torch.Tensor):
+        return F.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
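
The arithmetic follows the VCMRS reference: `quantize_weights` pre-quantizes the weights per output channel (scale `fw`), the input is rounded with a scale `fx` derived from its max magnitude so all products stay within float32's exactly representable integer range (`_precision = 2 ** (23 + 1)`), the ordinary `F.conv2d` then accumulates exact integers, and the result is rescaled in float before the bias (kept aside as `float_bias`) is added back. A self-contained sketch of that round trip, simplified to quantize the weights on the fly with default stride and padding:

    import torch
    import torch.nn.functional as F

    def int_conv2d_sketch(x, weight, bias, fx, fw):
        # Round input and weights so conv2d sums exact integers
        # (stored in float tensors); fw is one scale per output channel.
        xq = torch.round(fx * x)
        wq = torch.round(fw.view(-1, 1, 1, 1) * weight)
        y = F.conv2d(xq, wq)               # exact integer accumulation
        y = y / (fx * fw.view(-1, 1, 1))   # undo both scales
        return y + bias.view(1, -1, 1, 1)  # bias applied in float

cuDNN is disabled around the integer path above because its convolution kernels may take non-exact fast paths (e.g. Winograd) that would break the all-integer guarantee.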

compressai_vision/model_wrappers/jde.py

Lines changed: 3 additions & 3 deletions
@@ -88,7 +88,7 @@ def __init__(self, device: str, **kwargs):
             self.model_configs["frame_rate"] / 30.0 * self.model_configs["track_buffer"]
         )

-        integer_conv_weight = bool(kwargs["integer_conv_weight"])
+        _integer_conv_weight = bool(kwargs["integer_conv_weight"])

         assert "splits" in kwargs, "Split layer ids must be provided"
         self.split_layer_list = kwargs["splits"]
@@ -106,7 +106,7 @@ def __init__(self, device: str, **kwargs):
             param.requires_grad = False

         # must be called after loading weights to a model
-        if integer_conv_weight:
+        if _integer_conv_weight:
             self.darknet = self.quantize_weights(self.darknet)

         self.kalman_filter = KalmanFilter()
@@ -128,10 +128,10 @@ def reset(self):

     @staticmethod
     def quantize_weights(model):
-
         for module_def, module in zip(model.module_defs, model.module_list):
             if module_def["type"] == "convolutional":
                 for m in module:
+                    # print(type(m).__name__)
                     if type(m).__name__ == "IntConv2dWrapper":
                         m.quantize_weights()
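
Functionally the JDE path is unchanged: the flag is renamed to `_integer_conv_weight` to mark it as local, and the ordering constraint in the comment still applies, since `quantize_weights` rounds the loaded weights in place. A sketch of the required sequence; `Darknet` and `load_weights` are placeholders for the JDE model API, not names from this repo:

    model = Darknet(cfg_path)         # placeholder constructor
    model.load_weights(weights_path)  # placeholder loader; must come first
    for p in model.parameters():
        p.requires_grad = False       # freeze, as the wrapper does
    model = quantize_weights(model)   # only now is it safe to intify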

compressai_vision/model_wrappers/jde_lowlevel.py

Lines changed: 13 additions & 1 deletion
@@ -27,10 +27,11 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+import torch
 import torch.nn as nn
 from jde.models import EmptyLayer, Upsample, YOLOLayer

-from .intconv_wrapper import IntConv2dWrapper
+from .intconv2d import IntConv2d

 try:
     from jde.utils.syncbn import SyncBN
@@ -40,6 +41,17 @@
     batch_norm = nn.BatchNorm2d


+class IntConv2dWrapper(IntConv2d):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor):
+        if not self.initified_weight_mode:
+            return self.conv2d(x)
+
+        return self.integer_conv2d(x)
+
+
 def create_modules(module_defs, device: str):
     """
     Constructs module list of layer blocks from module configuration in module_defs
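
A short usage sketch of the two-mode dispatch, assuming `IntConv2dWrapper` above is importable; until `quantize_weights()` is called the module behaves as a plain convolution:

    import torch

    conv = IntConv2dWrapper(3, 16, kernel_size=3, padding=1, bias=True)
    x = torch.randn(1, 3, 32, 32)
    y_float = conv(x)        # initified_weight_mode is False: conv2d path
    conv.quantize_weights()  # intify weights in place; per the diff the
                             # bias is re-applied later as float_bias
    y_int = conv(x)          # integer_conv2d path, same output shape
    assert y_float.shape == y_int.shape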
