
Commit a58353c

chyomin06 authored and fracape committed
[dev] support datatype change at split inference for both image and video pipeline
1 parent 971b760 commit a58353c

3 files changed: +18 -4 lines changed


compressai_vision/pipelines/base.py

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ def __init__(
         self._create_folder(self.codec_output_dir)
         self.init_time_measure()
         self.init_complexity_measure()
+        self.eval()

     def init_time_measure(self):
         self.elapsed_time = {"nn_part_1": 0, "encode": 0, "decode": 0, "nn_part_2": 0}
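The one-line addition above puts the pipeline into inference mode as soon as it is constructed. Assuming the base pipeline class inherits from `torch.nn.Module` (not visible in this diff), `self.eval()` recursively switches every submodule out of training mode, so dropout layers become no-ops and batch norm uses its running statistics. A minimal sketch under that assumption:

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the real Pipeline base class.
    class TinyPipeline(nn.Module):
        def __init__(self):
            super().__init__()
            self.dropout = nn.Dropout(p=0.5)
            self.eval()  # as in the diff: inference mode from construction

    pipe = TinyPipeline()
    x = torch.ones(4)
    assert torch.equal(pipe.dropout(x), x)  # dropout is a no-op in eval mode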

compressai_vision/pipelines/split_inference/image_split_inference.py

Lines changed: 5 additions & 0 deletions
@@ -30,6 +30,7 @@
 import os
 from typing import Dict

+import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm

@@ -70,6 +71,7 @@ def __init__(
         device: Dict,
     ):
         super().__init__(configs, device)
+        self.datatype = configs["datatype"]

     def __call__(
         self,
@@ -116,6 +118,9 @@ def __call__(
             featureT = self._from_input_to_features(vision_model, d, file_prefix)
             self.update_time_elapsed("nn_part_1", (time_measure() - start))

+            # datatype conversion
+            # featureT["data"] = {k : v.type(getattr(torch, self.datatype)) for k, v in featureT["data"].items()}
+
             featureT["org_input_size"] = org_img_size

             start = time_measure()
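In the image pipeline the change is wired up but left dormant: the constructor now stores `configs["datatype"]`, while the cast itself ships commented out. If enabled, that line would cast every feature tensor to the configured dtype by resolving the dtype name on the `torch` module. A sketch of the mechanism, with "float16" standing in for the config value and a made-up feature key:

    import torch

    datatype = "float16"  # stand-in for configs["datatype"]
    featureT = {"data": {"p2": torch.randn(1, 256, 64, 64)}}  # hypothetical feature map

    # getattr(torch, "float16") resolves to torch.float16
    featureT["data"] = {
        k: v.type(getattr(torch, datatype))
        for k, v in featureT["data"].items()
    }
    assert featureT["data"]["p2"].dtype == torch.float16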

compressai_vision/pipelines/split_inference/video_split_inference.py

Lines changed: 12 additions & 4 deletions
@@ -177,10 +177,14 @@ def __call__(
         features["data"] = self._feature_tensor_list_to_dict(
             self._input_ftensor_buffer
         )
-        # datatype conversion
-        features["data"] = {k : v.type(getattr(torch, self.datatype)) for k, v in features["data"]}
         self._input_ftensor_buffer = []

+        # datatype conversion
+        features["data"] = {
+            k: v.type(getattr(torch, self.datatype))
+            for k, v in features["data"].items()
+        }
+
         # Feature Compression
         start = time_measure()
         res, enc_time_by_module, enc_complexity = self._compress(
@@ -250,9 +254,13 @@ def __call__(

         # separate a tensor of each keyword item into a list of tensors
         dec_ftensors_list = self._feature_tensor_dict_to_list(dec_features["data"])
-        assert all([self.datatype in str(d.dtype) for d in dec_ftensors_list[0].values()]), "Output features not of expected datatype"
+        assert all(
+            [self.datatype in str(d.dtype) for d in dec_ftensors_list[0].values()]
+        ), "Output features not of expected datatype"

-        dec_ftensors_list = [{k : v.type(torch.float32) for k, v in d.items()} for d in dec_ftensors_list]
+        dec_ftensors_list = [
+            {k: v.type(torch.float32) for k, v in d.items()} for d in dec_ftensors_list
+        ]

         assert len(dec_ftensors_list) == len(dataloader), (
             f"The number of decoded frames ({len(dec_ftensors_list)}) is not equal "
