Skip to content
Closed

Dev #15

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/static-gh-pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ on:
push:
branches:
- main
pull_request:
branches:
- main
- dev

jobs:
docs_to_gh-pages:
runs-on: ubuntu-latest
Expand Down
3 changes: 2 additions & 1 deletion cfgs/vision_model/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ yolox_darknet53:
conf_thres: 0.001
nms_thres: 0.65
weights: "weights/yolox/darknet53/yolox_darknet.pth"
splits: "l13" #"l37"
splits: "l13" #"l37"
squeeze_at_split: False
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2024 InterDigital Communications, Inc
# Copyright (c) 2025, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -27,20 +27,25 @@
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""cli load_eval functionality
"""

import torch.nn as nn

def main(p):
print("importing fiftyone")
import fiftyone as fo

print("fiftyone imported")
# dataset = fo.load_dataset(p.dataset_name)
print("removing dataset %s from fiftyone" % (p.dataset_name))
if not p.y:
input("press enter to continue.. ")
try:
fo.delete_dataset(p.dataset_name)
except ValueError as e:
print("could not deregister because of", e)
class squeeze_base(nn.Module):
    """Base class for modules that squeeze and re-expand a feature tensor.

    Subclasses either override :meth:`squeeze_` / :meth:`expand_` directly,
    or assign callables to ``squeeze_ftensor`` / ``expand_ftensor`` and rely
    on the default implementations below.
    """

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted (and ignored) so
        # that subclasses may forward their own constructor arguments.
        super().__init__()

        # Callables applied by the default squeeze_/expand_ implementations;
        # they stay None until a subclass or user provides them.
        self.squeeze_ftensor = None
        self.expand_ftensor = None

    @property
    def address(self):
        """Return the placeholder string; subclasses override it with a URL."""
        return "PROVIDE URL"

    def squeeze_(self, x):
        """Apply ``squeeze_ftensor`` to ``x``.

        Raises:
            NotImplementedError: if ``squeeze_ftensor`` has not been set and
                the method has not been overridden.
        """
        # You may implement your own
        if self.squeeze_ftensor is None:
            raise NotImplementedError(
                f"{type(self).__name__} must set 'squeeze_ftensor' or override 'squeeze_'"
            )
        return self.squeeze_ftensor(x)

    def expand_(self, x):
        """Apply ``expand_ftensor`` to ``x``.

        Raises:
            NotImplementedError: if ``expand_ftensor`` has not been set and
                the method has not been overridden.
        """
        # You may implement your own
        if self.expand_ftensor is None:
            raise NotImplementedError(
                f"{type(self).__name__} must set 'expand_ftensor' or override 'expand_'"
            )
        return self.expand_ftensor(x)
85 changes: 85 additions & 0 deletions compressai_vision/model_wrappers/split_squeezes/squeeze_yolox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2025, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import torch.nn as nn

from .squeeze_base import squeeze_base


# for YOLOX-Darknet53
class three_convs_at_l13(squeeze_base):
    """Three-convolution squeeze/expand pair.

    ``fw_block`` maps ``C0 -> C1 -> C2 -> C3`` channels (3x3, 3x3 with
    stride 2, then 1x1), and ``bw_block`` maps ``C3 -> C2 -> C1 -> C0`` with
    a x2 bilinear upsample after its first convolution.

    NOTE(review): the stride-2 convolution followed by a fixed x2 upsample
    only restores the input height/width when they are even -- confirm that
    callers guarantee this.
    """

    def __init__(self, C0, C1, C2, C3):
        super().__init__(C0, C1, C2, C3)

        def make_conv(c_in, c_out, ksize, stride=1):
            # "Same"-style padding for the odd kernel sizes used here
            # (3 -> 1, 1 -> 0).
            return nn.Conv2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=ksize,
                padding=ksize // 2,
                stride=stride,
            )

        # Layer order (and therefore the Sequential indices) must stay fixed:
        # the pretrained checkpoint is keyed on these positions.
        squeeze_layers = [
            make_conv(C0, C1, 3),
            nn.PReLU(),
            make_conv(C1, C2, 3, stride=2),
            nn.PReLU(),
            make_conv(C2, C3, 1),
            nn.SiLU(inplace=True),
        ]
        self.fw_block = nn.Sequential(*squeeze_layers)

        expand_layers = [
            make_conv(C3, C2, 3),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.PReLU(),
            make_conv(C2, C1, 3),
            nn.PReLU(),
            make_conv(C1, C0, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        ]
        self.bw_block = nn.Sequential(*expand_layers)

    @property
    def address(self):
        """URL of the pretrained weights for this squeeze module."""
        return "https://dspub.blob.core.windows.net/compressai-vision/split_squeezes/yolox_darknet53/three_convs_squeeze_at_l13_of_yolox_darknet53-f78179c1.pth"

    def squeeze_(self, x):
        """Run ``x`` through the forward (squeezing) block."""
        squeezed = self.fw_block(x)
        return squeezed

    def expand_(self, x):
        """Run ``x`` through the backward (expanding) block."""
        expanded = self.bw_block(x)
        return expanded

    def forward(self, x):
        """Squeeze then expand ``x`` and return the reconstruction."""
        return self.bw_block(self.fw_block(x))
76 changes: 60 additions & 16 deletions compressai_vision/model_wrappers/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import configparser
from enum import Enum
from pathlib import Path
from typing import Dict, List
Expand All @@ -40,6 +39,7 @@
from compressai_vision.registry import register_vision_model

from .base_wrapper import BaseWrapper
from .split_squeezes import squeeze_yolox

__all__ = [
"yolox_darknet53",
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(self, device: str, **kwargs):
self.conf_thres = kwargs["conf_thres"]
self.nms_thres = kwargs["nms_thres"]

self.supported_split_points = Split_Points
self.squeeze_at_split_enabled = False

exp = get_exp(exp_file=None, exp_name="yolov3")

Expand All @@ -85,9 +85,10 @@ def __init__(self, device: str, **kwargs):

assert "splits" in kwargs, "Split layer ids must be provided"
self.split_id = str(kwargs["splits"]).lower()
if self.split_id == str(self.supported_split_points.Layer13_Single):

if self.split_id == str(Split_Points.Layer13_Single):
self.split_layer_list = ["l13"]
elif self.split_id == str(self.supported_split_points.Layer37_Single):
elif self.split_id == str(Split_Points.Layer37_Single):
self.split_layer_list = ["l37"]
else:
raise NotImplementedError
Expand All @@ -100,8 +101,12 @@ def __init__(self, device: str, **kwargs):
torch.load(self.model_info["weights"], map_location="cpu")["model"],
strict=False,
)

self.model.to(device).eval()

if bool(kwargs["squeeze_at_split"]):
self.enable_squeeze_at_split(self.split_id)

self.yolo_fpn = self.model.backbone
self.backbone = self.yolo_fpn.backbone
self.head = self.model.head
Expand All @@ -112,11 +117,38 @@ def __init__(self, device: str, **kwargs):

@property
def SPLIT_L13(self):
return str(self.supported_split_points.Layer13_Single)
return str(Split_Points.Layer13_Single)

@property
def SPLIT_L37(self):
return str(self.supported_split_points.Layer37_Single)
return str(Split_Points.Layer37_Single)

def enable_squeeze_at_split(self, split_id):
    """Load the pretrained squeeze module for ``split_id`` and enable it.

    If ``split_id`` has no squeeze module available, a warning is logged and
    ``squeeze_at_split_enabled`` is set to ``False``.

    The model is built and its weights are loaded *before* anything is stored
    on ``self``. A failed download or a state-dict mismatch therefore cannot
    leave the wrapper flagged as squeeze-enabled with randomly initialised
    weights.
    """
    from torch.hub import load_state_dict_from_url

    LIST_OF_SQUEEZE_SUPPORT_SPLITS = [str(Split_Points.Layer13_Single)]

    if split_id not in LIST_OF_SQUEEZE_SUPPORT_SPLITS:
        self.logger.warning(
            f"Squeeze is not available at {split_id}. Currently only available at {LIST_OF_SQUEEZE_SUPPORT_SPLITS}"
        )
        self.squeeze_at_split_enabled = False
        return

    squeeze_model = squeeze_yolox.three_convs_at_l13(
        C0=256, C1=256, C2=128, C3=128
    )

    state_dict = load_state_dict_from_url(
        squeeze_model.address,
        progress=True,
        check_hash=True,
        map_location=self.device,
    )

    squeeze_model.load_state_dict(state_dict)
    squeeze_model.to(self.device).eval()

    # Publish the model and flip the flag only once loading has succeeded.
    self.squeeze_model = squeeze_model
    self.squeeze_at_split_enabled = True

def input_to_features(self, x, device: str) -> Dict:
"""Computes deep features at the intermediate layer(s) all the way from the input"""
Expand All @@ -126,9 +158,9 @@ def input_to_features(self, x, device: str) -> Dict:
input_size = tuple(img.shape[2:])

if self.split_id == self.SPLIT_L13:
output = self._input_to_feature_at_l13(img)
output = self._input_to_feature_at_l13(img, device)
elif self.split_id == self.SPLIT_L37:
output = self._input_to_feature_at_l37(img)
output = self._input_to_feature_at_l37(img, device)
else:
self.logger.error(f"Not supported split point {self.split_id}")
raise NotImplementedError
Expand All @@ -143,29 +175,36 @@ def features_to_output(self, x: Dict, device: str):

if self.split_id == self.SPLIT_L13:
return self._feature_at_l13_to_output(
x["data"], x["org_input_size"], x["input_size"]
x["data"], x["org_input_size"], x["input_size"], device
)
elif self.split_id == self.SPLIT_L37:
return self._feature_at_l37_to_output(
x["data"], x["org_input_size"], x["input_size"]
x["data"], x["org_input_size"], x["input_size"], device
)
else:
self.logger.error(f"Not supported split points {self.split_id}")

raise NotImplementedError

@torch.no_grad()
def _input_to_feature_at_l13(self, x):
def _input_to_feature_at_l13(self, x, device):
"""Computes and return feature at layer 13 with leaky relu all the way from the input"""

y = self.backbone.stem(x)
y = self.backbone.dark2(y)
self.features_at_splits[self.SPLIT_L13] = self.backbone.dark3[0](y)
y = self.backbone.dark3[0](y)

if not self.squeeze_at_split_enabled:
self.features_at_splits[self.SPLIT_L13] = y
return {"data": self.features_at_splits}

# Further squeeze
smodel = self.squeeze_model.to(device)
self.features_at_splits[self.SPLIT_L13] = smodel.squeeze_(y)
return {"data": self.features_at_splits}

@torch.no_grad()
def _input_to_feature_at_l37(self, x):
def _input_to_feature_at_l37(self, x, device):
"""Computes and return feature at layer 37 with 11th residual layer output all the way from the input"""

y = self.backbone.stem(x)
Expand All @@ -177,7 +216,7 @@ def _input_to_feature_at_l37(self, x):

@torch.no_grad()
def _feature_at_l13_to_output(
self, x: Dict, org_img_size: Dict, input_img_size: List
self, x: Dict, org_img_size: Dict, input_img_size: List, device
):
"""
performs downstream task using the features from layer 13
Expand All @@ -191,8 +230,13 @@ def _feature_at_l13_to_output(
<https://github.com/Megvii-BaseDetection/YOLOX?tab=Apache-2.0-1-ov-file#readme>

"""

y = x[self.SPLIT_L13]

# Expand the squeezed feature back to its original dimensions
if self.squeeze_at_split_enabled:
smodel = self.squeeze_model.to(device)
y = smodel.expand_(y)

for proc_module in self.backbone.dark3[1:]:
y = proc_module(y)

Expand Down Expand Up @@ -220,7 +264,7 @@ def _feature_at_l13_to_output(

@torch.no_grad()
def _feature_at_l37_to_output(
self, x: Dict, org_img_size: Dict, input_img_size: List
self, x: Dict, org_img_size: Dict, input_img_size: List, device
):
"""
performs downstream task using the features from layer 37
Expand Down
41 changes: 36 additions & 5 deletions compressai_vision/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def _prep_features_to_dump(features, n_bits, datacatalog_name):
if n_bits == -1:
data_features = features["data"]
elif n_bits >= 8:
assert n_bits == 8, "currently it only supports dumping features in 8 bits"
assert (
n_bits == 8 or n_bits == 16
), "currently it only supports dumping features in 8 bits or 16 bits"
assert datacatalog_name in list(
MIN_MAX_DATASET.keys()
), f"{datacatalog_name} does not exist in the pre-computed minimum and maximum tables"
Expand All @@ -218,7 +220,21 @@ def _prep_features_to_dump(features, n_bits, datacatalog_name):
data.min() >= minv and data.max() <= maxv
), f"{data.min()} should be greater than {minv} and {data.max()} should be less than {maxv}"
out, _ = min_max_normalization(data, minv, maxv, bitdepth=n_bits)
data_features[key] = out.to(torch.uint8)

if n_bits <= 8:
data_features[key] = out.to(torch.uint8)
elif n_bits <= 16:
data_features[key] = {
"lsb": torch.bitwise_and(
out.to(torch.int32), torch.tensor(0xFF)
).to(torch.uint8),
"msb": torch.bitwise_and(
torch.bitwise_right_shift(out.to(torch.int32), 8),
torch.tensor(0xFF),
).to(torch.uint8),
}
else:
raise NotImplementedError
else:
raise NotImplementedError

Expand All @@ -230,15 +246,30 @@ def _post_process_loaded_features(features, n_bits, datacatalog_name):
if n_bits == -1:
assert "data" in features
elif n_bits >= 8:
assert n_bits == 8, "currently it only supports dumping features in 8 bits"
assert (
n_bits == 8 or n_bits == 16
), "currently it only supports dumping features in 8 bits or 16 bits"
assert datacatalog_name in list(
MIN_MAX_DATASET.keys()
), f"{datacatalog_name} does not exist in the pre-computed minimum and maximum tables"
minv, maxv = MIN_MAX_DATASET[datacatalog_name]
data_features = {}
for key, data in features["data"].items():
out = min_max_inv_normalization(data, minv, maxv, bitdepth=n_bits)
data_features[key] = out.to(torch.float32)

if n_bits <= 8:
out = min_max_inv_normalization(data, minv, maxv, bitdepth=n_bits)
data_features[key] = out.to(torch.float32)
elif n_bits <= 16:
lsb_part = data["lsb"].to(torch.int32)
msb_part = torch.bitwise_left_shift(data["msb"].to(torch.int32), 8)
recovery = (msb_part + lsb_part).to(torch.float32)

out = min_max_inv_normalization(
recovery, minv, maxv, bitdepth=n_bits
)
data_features[key] = out.to(torch.float32)
else:
raise NotImplementedError

features["data"] = data_features
else:
Expand Down
Loading