Skip to content

Commit 7924471

Browse files
chyomin06fracape
authored and committed
[feat] support split squeeze natural bottleneck case for YOLOX Darknet53 at l13
1 parent cf0dbc1 commit 7924471

File tree

4 files changed

+198
-17
lines changed

4 files changed

+198
-17
lines changed

cfgs/vision_model/default.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,5 @@ yolox_darknet53:
4646
conf_thres: 0.001
4747
nms_thres: 0.65
4848
weights: "weights/yolox/darknet53/yolox_darknet.pth"
49-
splits: "l13" #"l37"
49+
splits: "l13" #"l37"
50+
squeeze_at_split: False
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2025, InterDigital Communications, Inc
2+
# All rights reserved.
3+
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted (subject to the limitations in the disclaimer
6+
# below) provided that the following conditions are met:
7+
8+
# * Redistributions of source code must retain the above copyright notice,
9+
# this list of conditions and the following disclaimer.
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
# * Neither the name of InterDigital Communications, Inc nor the names of its
14+
# contributors may be used to endorse or promote products derived from this
15+
# software without specific prior written permission.
16+
17+
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
18+
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
19+
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
20+
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
21+
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
22+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
25+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
26+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
27+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
28+
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
30+
31+
import torch.nn as nn
32+
33+
34+
class squeeze_base(nn.Module):
    """Base class for a squeeze/expand module pair.

    The class exposes two operations, ``squeeze_`` and ``expand_``, which by
    default delegate to the callables stored in ``squeeze_ftensor`` and
    ``expand_ftensor``. Both hooks start out as ``None``; a subclass must
    either assign callables to them or override ``squeeze_`` / ``expand_``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the module with both hooks unset.

        Args:
            *args: Accepted and ignored by the base class.
            **kwargs: Accepted and ignored by the base class.
        """
        super().__init__()

        # Hooks used by the default squeeze_/expand_ implementations.
        self.squeeze_ftensor = None
        self.expand_ftensor = None

    @property
    def address(self):
        """Return the placeholder string ``"PROVIDE URL"``.

        Subclasses are expected to override this (presumably with the
        location of their weights -- confirm against the subclass).
        """
        return "PROVIDE URL"

    def squeeze_(self, x):
        """Apply ``squeeze_ftensor`` to ``x`` and return the result.

        Raises:
            NotImplementedError: If ``squeeze_ftensor`` has not been set and
                this method has not been overridden.
        """
        # You may implement your own
        if self.squeeze_ftensor is None:
            # Fail with an actionable message instead of
            # "'NoneType' object is not callable".
            raise NotImplementedError(
                f"{type(self).__name__} must set 'squeeze_ftensor' or override 'squeeze_'"
            )
        return self.squeeze_ftensor(x)

    def expand_(self, x):
        """Apply ``expand_ftensor`` to ``x`` and return the result.

        Raises:
            NotImplementedError: If ``expand_ftensor`` has not been set and
                this method has not been overridden.
        """
        # You may implement your own
        if self.expand_ftensor is None:
            raise NotImplementedError(
                f"{type(self).__name__} must set 'expand_ftensor' or override 'expand_'"
            )
        return self.expand_ftensor(x)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) 2025, InterDigital Communications, Inc
2+
# All rights reserved.
3+
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted (subject to the limitations in the disclaimer
6+
# below) provided that the following conditions are met:
7+
8+
# * Redistributions of source code must retain the above copyright notice,
9+
# this list of conditions and the following disclaimer.
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
# * Neither the name of InterDigital Communications, Inc nor the names of its
14+
# contributors may be used to endorse or promote products derived from this
15+
# software without specific prior written permission.
16+
17+
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
18+
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
19+
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
20+
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
21+
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
22+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
25+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
26+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
27+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
28+
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
30+
31+
import torch.nn as nn
32+
33+
from .squeeze_base import squeeze_base
34+
35+
36+
# for YOLOX-Darknet53
37+
class three_convs_at_l13(squeeze_base):
    """Three-convolution squeeze block and its matching expand block.

    ``fw_block`` maps ``C0`` channels to ``C3`` channels through three
    convolutions, the second of which uses stride 2. ``bw_block`` maps ``C3``
    channels back to ``C0`` channels and contains a x2 bilinear upsampling.

    Args:
        C0: Channels entering ``fw_block`` and leaving ``bw_block``.
        C1: Channels after the first squeeze convolution.
        C2: Channels after the second (stride-2) squeeze convolution.
        C3: Channels of the squeezed tensor.
    """

    def __init__(self, C0, C1, C2, C3):
        super().__init__(C0, C1, C2, C3)

        # Layer order fixes the Sequential indices, and therefore the
        # state-dict keys, so it must not be rearranged.
        squeeze_layers = [
            nn.Conv2d(C0, C1, kernel_size=3, stride=1, padding=1),
            nn.PReLU(),
            nn.Conv2d(C1, C2, kernel_size=3, stride=2, padding=1),
            nn.PReLU(),
            nn.Conv2d(C2, C3, kernel_size=1, stride=1, padding=0),
            nn.SiLU(inplace=True),
        ]
        self.fw_block = nn.Sequential(*squeeze_layers)

        expand_layers = [
            nn.Conv2d(C3, C2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.PReLU(),
            nn.Conv2d(C2, C1, kernel_size=3, stride=1, padding=1),
            nn.PReLU(),
            nn.Conv2d(C1, C0, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        ]
        self.bw_block = nn.Sequential(*expand_layers)

    @property
    def address(self):
        """Return the URL of the ``.pth`` file associated with this module."""
        return "https://dspub.blob.core.windows.net/compressai-vision/split_squeezes/yolox_darknet53/three_convs_squeeze_at_l13_of_yolox_darknet53-f78179c1.pth"

    def squeeze_(self, x):
        """Run ``x`` through ``fw_block`` and return the squeezed tensor."""
        squeezed = self.fw_block(x)
        return squeezed

    def expand_(self, x):
        """Run ``x`` through ``bw_block`` and return the expanded tensor."""
        expanded = self.bw_block(x)
        return expanded

    def forward(self, x):
        """Squeeze then expand ``x`` and return the reconstruction."""
        return self.bw_block(self.fw_block(x))

compressai_vision/model_wrappers/yolox.py

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929

3030

31-
import configparser
3231
from enum import Enum
3332
from pathlib import Path
3433
from typing import Dict, List
@@ -40,6 +39,7 @@
4039
from compressai_vision.registry import register_vision_model
4140

4241
from .base_wrapper import BaseWrapper
42+
from .split_squeezes import squeeze_yolox
4343

4444
__all__ = [
4545
"yolox_darknet53",
@@ -76,7 +76,7 @@ def __init__(self, device: str, **kwargs):
7676
self.conf_thres = kwargs["conf_thres"]
7777
self.nms_thres = kwargs["nms_thres"]
7878

79-
self.supported_split_points = Split_Points
79+
self.squeeze_at_split_enabled = False
8080

8181
exp = get_exp(exp_file=None, exp_name="yolov3")
8282

@@ -85,9 +85,10 @@ def __init__(self, device: str, **kwargs):
8585

8686
assert "splits" in kwargs, "Split layer ids must be provided"
8787
self.split_id = str(kwargs["splits"]).lower()
88-
if self.split_id == str(self.supported_split_points.Layer13_Single):
88+
89+
if self.split_id == str(Split_Points.Layer13_Single):
8990
self.split_layer_list = ["l13"]
90-
elif self.split_id == str(self.supported_split_points.Layer37_Single):
91+
elif self.split_id == str(Split_Points.Layer37_Single):
9192
self.split_layer_list = ["l37"]
9293
else:
9394
raise NotImplementedError
@@ -100,8 +101,12 @@ def __init__(self, device: str, **kwargs):
100101
torch.load(self.model_info["weights"], map_location="cpu")["model"],
101102
strict=False,
102103
)
104+
103105
self.model.to(device).eval()
104106

107+
if bool(kwargs["squeeze_at_split"]):
108+
self.enable_squeeze_at_split(self.split_id)
109+
105110
self.yolo_fpn = self.model.backbone
106111
self.backbone = self.yolo_fpn.backbone
107112
self.head = self.model.head
@@ -112,11 +117,38 @@ def __init__(self, device: str, **kwargs):
112117

113118
@property
def SPLIT_L13(self):
    """Return the string form of ``Split_Points.Layer13_Single``."""
    split_point = Split_Points.Layer13_Single
    return str(split_point)
116121

117122
@property
def SPLIT_L37(self):
    """Return the string form of ``Split_Points.Layer37_Single``."""
    split_point = Split_Points.Layer37_Single
    return str(split_point)
125+
126+
def enable_squeeze_at_split(self, split_id):
    """Load the squeeze model for ``split_id`` and turn squeezing on.

    When ``split_id`` is a supported split, a ``three_convs_at_l13`` model is
    built, its weights are fetched from ``address`` with hash checking, and
    the model is moved to ``self.device`` in eval mode. Otherwise a warning
    is logged and squeezing is left disabled.

    ``self.squeeze_model`` and ``self.squeeze_at_split_enabled`` are only
    updated after the weights have loaded successfully, so a failed download
    or a mismatched state dict cannot leave the wrapper reporting squeeze as
    enabled with an unloaded model.

    Args:
        split_id: Split identifier string, compared against
            ``str(Split_Points.Layer13_Single)``.
    """
    from torch.hub import load_state_dict_from_url

    LIST_OF_SQUEEZE_SUPPORT_SPLITS = [str(Split_Points.Layer13_Single)]

    if split_id not in LIST_OF_SQUEEZE_SUPPORT_SPLITS:
        self.logger.warning(
            f"Squeeze is not available at {split_id}. Currently only available at {LIST_OF_SQUEEZE_SUPPORT_SPLITS}"
        )
        self.squeeze_at_split_enabled = False
        return

    # Build and load into a local first; publish on success only.
    squeeze_model = squeeze_yolox.three_convs_at_l13(C0=256, C1=256, C2=128, C3=128)

    state_dict = load_state_dict_from_url(
        squeeze_model.address,
        progress=True,
        check_hash=True,
        map_location=self.device,
    )

    squeeze_model.load_state_dict(state_dict)
    squeeze_model.to(self.device).eval()

    self.squeeze_model = squeeze_model
    self.squeeze_at_split_enabled = True
120152

121153
def input_to_features(self, x, device: str) -> Dict:
122154
"""Computes deep features at the intermediate layer(s) all the way from the input"""
@@ -126,9 +158,9 @@ def input_to_features(self, x, device: str) -> Dict:
126158
input_size = tuple(img.shape[2:])
127159

128160
if self.split_id == self.SPLIT_L13:
129-
output = self._input_to_feature_at_l13(img)
161+
output = self._input_to_feature_at_l13(img, device)
130162
elif self.split_id == self.SPLIT_L37:
131-
output = self._input_to_feature_at_l37(img)
163+
output = self._input_to_feature_at_l37(img, device)
132164
else:
133165
self.logger.error(f"Not supported split point {self.split_id}")
134166
raise NotImplementedError
@@ -143,29 +175,36 @@ def features_to_output(self, x: Dict, device: str):
143175

144176
if self.split_id == self.SPLIT_L13:
145177
return self._feature_at_l13_to_output(
146-
x["data"], x["org_input_size"], x["input_size"]
178+
x["data"], x["org_input_size"], x["input_size"], device
147179
)
148180
elif self.split_id == self.SPLIT_L37:
149181
return self._feature_at_l37_to_output(
150-
x["data"], x["org_input_size"], x["input_size"]
182+
x["data"], x["org_input_size"], x["input_size"], device
151183
)
152184
else:
153185
self.logger.error(f"Not supported split points {self.split_id}")
154186

155187
raise NotImplementedError
156188

157189
@torch.no_grad()
158-
def _input_to_feature_at_l13(self, x):
190+
def _input_to_feature_at_l13(self, x, device):
159191
"""Computes and return feature at layer 13 with leaky relu all the way from the input"""
160192

161193
y = self.backbone.stem(x)
162194
y = self.backbone.dark2(y)
163-
self.features_at_splits[self.SPLIT_L13] = self.backbone.dark3[0](y)
195+
y = self.backbone.dark3[0](y)
164196

197+
if not self.squeeze_at_split_enabled:
198+
self.features_at_splits[self.SPLIT_L13] = y
199+
return {"data": self.features_at_splits}
200+
201+
# Further squeeze
202+
smodel = self.squeeze_model.to(device)
203+
self.features_at_splits[self.SPLIT_L13] = smodel.squeeze_(y)
165204
return {"data": self.features_at_splits}
166205

167206
@torch.no_grad()
168-
def _input_to_feature_at_l37(self, x):
207+
def _input_to_feature_at_l37(self, x, device):
169208
"""Computes and return feature at layer 37 with 11th residual layer output all the way from the input"""
170209

171210
y = self.backbone.stem(x)
@@ -177,7 +216,7 @@ def _input_to_feature_at_l37(self, x):
177216

178217
@torch.no_grad()
179218
def _feature_at_l13_to_output(
180-
self, x: Dict, org_img_size: Dict, input_img_size: List
219+
self, x: Dict, org_img_size: Dict, input_img_size: List, device
181220
):
182221
"""
183222
performs downstream task using the features from layer 13
@@ -191,8 +230,13 @@ def _feature_at_l13_to_output(
191230
<https://github.com/Megvii-BaseDetection/YOLOX?tab=Apache-2.0-1-ov-file#readme>
192231
193232
"""
194-
195233
y = x[self.SPLIT_L13]
234+
235+
# Recovery session to expand dimension to original
236+
if self.squeeze_at_split_enabled:
237+
smodel = self.squeeze_model.to(device)
238+
y = smodel.expand_(y)
239+
196240
for proc_module in self.backbone.dark3[1:]:
197241
y = proc_module(y)
198242

@@ -220,7 +264,7 @@ def _feature_at_l13_to_output(
220264

221265
@torch.no_grad()
222266
def _feature_at_l37_to_output(
223-
self, x: Dict, org_img_size: Dict, input_img_size: List
267+
self, x: Dict, org_img_size: Dict, input_img_size: List, device
224268
):
225269
"""
226270
performs downstream task using the features from layer 37

0 commit comments

Comments
 (0)