2828# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929
3030
31- import configparser
3231from enum import Enum
3332from pathlib import Path
3433from typing import Dict , List
4039from compressai_vision .registry import register_vision_model
4140
4241from .base_wrapper import BaseWrapper
42+ from .split_squeezes import squeeze_yolox
4343
4444__all__ = [
4545 "yolox_darknet53" ,
@@ -76,7 +76,7 @@ def __init__(self, device: str, **kwargs):
7676 self .conf_thres = kwargs ["conf_thres" ]
7777 self .nms_thres = kwargs ["nms_thres" ]
7878
79- self .supported_split_points = Split_Points
79+ self .squeeze_at_split_enabled = False
8080
8181 exp = get_exp (exp_file = None , exp_name = "yolov3" )
8282
@@ -85,9 +85,10 @@ def __init__(self, device: str, **kwargs):
8585
8686 assert "splits" in kwargs , "Split layer ids must be provided"
8787 self .split_id = str (kwargs ["splits" ]).lower ()
88- if self .split_id == str (self .supported_split_points .Layer13_Single ):
88+
89+ if self .split_id == str (Split_Points .Layer13_Single ):
8990 self .split_layer_list = ["l13" ]
90- elif self .split_id == str (self . supported_split_points .Layer37_Single ):
91+ elif self .split_id == str (Split_Points .Layer37_Single ):
9192 self .split_layer_list = ["l37" ]
9293 else :
9394 raise NotImplementedError
@@ -100,8 +101,12 @@ def __init__(self, device: str, **kwargs):
100101 torch .load (self .model_info ["weights" ], map_location = "cpu" )["model" ],
101102 strict = False ,
102103 )
104+
103105 self .model .to (device ).eval ()
104106
107+ if bool (kwargs ["squeeze_at_split" ]):
108+ self .enable_squeeze_at_split (self .split_id )
109+
105110 self .yolo_fpn = self .model .backbone
106111 self .backbone = self .yolo_fpn .backbone
107112 self .head = self .model .head
@@ -112,11 +117,38 @@ def __init__(self, device: str, **kwargs):
112117
@property
def SPLIT_L13(self):
    """Return ``str(Split_Points.Layer13_Single)``.

    This is the string ``self.split_id`` is compared against when the
    wrapper dispatches to the layer-13 split code paths.
    """
    layer13_member = Split_Points.Layer13_Single
    return str(layer13_member)
116121
@property
def SPLIT_L37(self):
    """Return ``str(Split_Points.Layer37_Single)``.

    This is the string ``self.split_id`` is compared against when the
    wrapper dispatches to the layer-37 split code paths.
    """
    layer37_member = Split_Points.Layer37_Single
    return str(layer37_member)
125+
def enable_squeeze_at_split(self, split_id):
    """Turn on the extra squeeze/expand stage at the split point, if supported.

    For a supported ``split_id`` a ``squeeze_yolox.three_convs_at_l13`` model
    is created, its weights are fetched from ``<model>.address`` with
    ``torch.hub.load_state_dict_from_url`` (hash-checked, mapped to
    ``self.device``), loaded, and the model is moved to ``self.device`` in
    eval mode. Only after all of that succeeds are ``self.squeeze_model``
    and ``self.squeeze_at_split_enabled = True`` published, so a failed
    download or state-dict mismatch never leaves the wrapper claiming that
    squeezing is enabled while holding an unloaded model.

    For any other ``split_id`` a warning is logged and
    ``self.squeeze_at_split_enabled`` is set to ``False``.

    Args:
        split_id: split identifier string, compared against
            ``str(Split_Points.Layer13_Single)``.
    """
    from torch.hub import load_state_dict_from_url

    LIST_OF_SQUEEZE_SUPPORT_SPLITS = [str(Split_Points.Layer13_Single)]

    # Guard clause: nothing to load for splits without a squeeze model.
    if split_id not in LIST_OF_SQUEEZE_SUPPORT_SPLITS:
        self.logger.warning(
            f"Squeeze is not available at { split_id } . Currently only available at { LIST_OF_SQUEEZE_SUPPORT_SPLITS } "
        )
        self.squeeze_at_split_enabled = False
        return

    # Build and load into a local first; any exception below propagates
    # to the caller with the instance state left untouched.
    squeeze_model = squeeze_yolox.three_convs_at_l13(
        C0=256, C1=256, C2=128, C3=128
    )

    state_dict = load_state_dict_from_url(
        squeeze_model.address,
        progress=True,
        check_hash=True,
        map_location=self.device,
    )

    squeeze_model.load_state_dict(state_dict)
    squeeze_model.to(self.device).eval()

    # Publish only once the model is fully usable.
    self.squeeze_model = squeeze_model
    self.squeeze_at_split_enabled = True
120152
121153 def input_to_features (self , x , device : str ) -> Dict :
122154 """Computes deep features at the intermediate layer(s) all the way from the input"""
@@ -126,9 +158,9 @@ def input_to_features(self, x, device: str) -> Dict:
126158 input_size = tuple (img .shape [2 :])
127159
128160 if self .split_id == self .SPLIT_L13 :
129- output = self ._input_to_feature_at_l13 (img )
161+ output = self ._input_to_feature_at_l13 (img , device )
130162 elif self .split_id == self .SPLIT_L37 :
131- output = self ._input_to_feature_at_l37 (img )
163+ output = self ._input_to_feature_at_l37 (img , device )
132164 else :
133165 self .logger .error (f"Not supported split point { self .split_id } " )
134166 raise NotImplementedError
@@ -143,29 +175,36 @@ def features_to_output(self, x: Dict, device: str):
143175
144176 if self .split_id == self .SPLIT_L13 :
145177 return self ._feature_at_l13_to_output (
146- x ["data" ], x ["org_input_size" ], x ["input_size" ]
178+ x ["data" ], x ["org_input_size" ], x ["input_size" ], device
147179 )
148180 elif self .split_id == self .SPLIT_L37 :
149181 return self ._feature_at_l37_to_output (
150- x ["data" ], x ["org_input_size" ], x ["input_size" ]
182+ x ["data" ], x ["org_input_size" ], x ["input_size" ], device
151183 )
152184 else :
153185 self .logger .error (f"Not supported split points { self .split_id } " )
154186
155187 raise NotImplementedError
156188
@torch.no_grad()
def _input_to_feature_at_l13(self, x, device):
    """Computes and return feature at layer 13 with leaky relu all the way from the input

    The input is passed through ``backbone.stem``, ``backbone.dark2`` and the
    first module of ``backbone.dark3``. When ``squeeze_at_split_enabled`` is
    set, the result is additionally passed through ``squeeze_model.squeeze_``
    after moving the squeeze model to ``device``. The feature is stored in
    ``self.features_at_splits`` under ``self.SPLIT_L13`` and that mapping is
    returned as ``{"data": self.features_at_splits}``.
    """
    backbone = self.backbone
    feature = backbone.dark3[0](backbone.dark2(backbone.stem(x)))

    if self.squeeze_at_split_enabled:
        # Further squeeze
        squeezer = self.squeeze_model.to(device)
        feature = squeezer.squeeze_(feature)

    self.features_at_splits[self.SPLIT_L13] = feature
    return {"data": self.features_at_splits}
166205
167206 @torch .no_grad ()
168- def _input_to_feature_at_l37 (self , x ):
207+ def _input_to_feature_at_l37 (self , x , device ):
169208 """Computes and return feature at layer 37 with 11th residual layer output all the way from the input"""
170209
171210 y = self .backbone .stem (x )
@@ -177,7 +216,7 @@ def _input_to_feature_at_l37(self, x):
177216
178217 @torch .no_grad ()
179218 def _feature_at_l13_to_output (
180- self , x : Dict , org_img_size : Dict , input_img_size : List
219+ self , x : Dict , org_img_size : Dict , input_img_size : List , device
181220 ):
182221 """
183222 performs downstream task using the features from layer 13
@@ -191,8 +230,13 @@ def _feature_at_l13_to_output(
191230 <https://github.com/Megvii-BaseDetection/YOLOX?tab=Apache-2.0-1-ov-file#readme>
192231
193232 """
194-
195233 y = x [self .SPLIT_L13 ]
234+
235+ # Recovery session to expand dimension to original
236+ if self .squeeze_at_split_enabled :
237+ smodel = self .squeeze_model .to (device )
238+ y = smodel .expand_ (y )
239+
196240 for proc_module in self .backbone .dark3 [1 :]:
197241 y = proc_module (y )
198242
@@ -220,7 +264,7 @@ def _feature_at_l13_to_output(
220264
221265 @torch .no_grad ()
222266 def _feature_at_l37_to_output (
223- self , x : Dict , org_img_size : Dict , input_img_size : List
267+ self , x : Dict , org_img_size : Dict , input_img_size : List , device
224268 ):
225269 """
226270 performs downstream task using the features from layer 37
0 commit comments