
Commit 8550a64

jianwensong authored and fracape committed
[feat] hash for features prior to nn part2
1 parent 5f36ce2 commit 8550a64

File tree

cfgs/pipeline/split_inference.yaml
compressai_vision/pipelines/base.py
compressai_vision/utils/__init__.py
compressai_vision/utils/hash.py

4 files changed: +125 -0 lines changed

cfgs/pipeline/split_inference.yaml

Lines changed: 4 additions & 0 deletions
@@ -36,6 +36,10 @@ codec:
   nn_task_part2:
     dump_results: False
     output_results_dir: "${codec.output_dir}/output_results"
+    dump_features: False
+    feature_dir: "${..output_dir_root}/features_pre_nn_part2/${dataset.datacatalog}/${dataset.config.dataset_name}"
+    dump_features_hash: False
+    hash_format: md5
   conformance:
     save_conformance_files: False
     subsample_ratio: 9
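
The four new options control an optional dump of the features fed to NN part 2: dump_features writes the serialized features into feature_dir, dump_features_hash writes only a digest of them (computed with hash_format) instead, and hashing takes precedence when both are enabled. Below is a hedged sketch (not part of the commit) of how the flags select the per-sequence artifact, mirroring the branch added to compressai_vision/pipelines/base.py further down; feature_artifact is a hypothetical helper, and the output_ext default stands in for the pipeline's real self._output_ext, which this diff does not show.

def feature_artifact(feature_dir, seq_name, dump_features, dump_features_hash,
                     hash_format="md5", output_ext=".bin"):
    if not (dump_features or dump_features_hash):
        return None  # nothing is dumped
    # hash-only mode writes "<seq_name>.<hash_format>" (a hex digest) instead of the features
    ext = f".{hash_format}" if dump_features_hash else output_ext
    return f"{feature_dir}/{seq_name}{ext}"

print(feature_artifact("features_pre_nn_part2/catalog/dataset", "seq0",
                       dump_features=False, dump_features_hash=True))
# -> features_pre_nn_part2/catalog/dataset/seq0.md5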

compressai_vision/pipelines/base.py

Lines changed: 46 additions & 0 deletions
@@ -32,6 +32,7 @@
 import logging
 import os
 
+from collections import OrderedDict
 from enum import Enum
 from pathlib import Path
 from typing import Callable, Dict, Tuple
@@ -49,6 +50,7 @@
     min_max_normalization,
 )
 from compressai_vision.model_wrappers import BaseWrapper
+from compressai_vision.utils import FileLikeHasher, freeze_zip_timestamps
 
 
 class Parts(Enum):
@@ -349,6 +351,12 @@ def _from_features_to_output(
         """performs the inference of the 2nd part of the NN model"""
         output_results_dir = self.configs["nn_task_part2"].output_results_dir
 
+        seq_name = (
+            seq_name
+            if seq_name is not None
+            else os.path.splitext(os.path.basename(x.get("file_name", "")))[0]
+        )
+
         results_file = f"{output_results_dir}/{seq_name}{self._output_ext}"
 
         assert "data" in x
@@ -374,6 +382,44 @@ def _from_features_to_output(
             for k, v in zip(vision_model.split_layer_list, x["data"].values())
         }
 
+        if (
+            self.configs["nn_task_part2"].dump_features
+            or self.configs["nn_task_part2"].dump_features_hash
+        ):
+            feature_dir = self.configs["nn_task_part2"].feature_dir
+            self._create_folder(feature_dir)
+
+            dump_feature_hash = self.configs["nn_task_part2"].dump_features_hash
+            hash_format = self.configs["nn_task_part2"].hash_format
+
+            feature_output_ext = (
+                f".{hash_format}" if dump_feature_hash else self._output_ext
+            )
+            path = f"{feature_dir}/{seq_name}{feature_output_ext}"
+
+            features_file = (
+                FileLikeHasher(path, hash_format) if dump_feature_hash else path
+            )
+
+            self.logger.debug(f"dumping features prior to nn part2 in: {feature_dir}")
+
+            # [TODO] align with nn_task_part1 dump features
+            features_to_dump = {
+                k: v
+                for k, v in sorted(x.items(), key=lambda kv: str(kv[0]))
+                if not k.startswith("file")
+            }
+
+            with freeze_zip_timestamps():
+                if dump_feature_hash:
+                    torch.save(features_to_dump, features_file, pickle_protocol=4)
+                else:
+                    with open(features_file, "wb") as f:
+                        torch.save(features_to_dump, f, pickle_protocol=4)
+
+            if hasattr(features_file, "close"):
+                features_file.close()
+
         results = vision_model.features_to_output(x, self.device_nn_part2)
         if self.configs["nn_task_part2"].dump_results:
             self._create_folder(output_results_dir)
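
With dump_features_hash enabled, torch.save streams the serialized bytes into a FileLikeHasher, so only the hex digest reaches disk. A hedged sketch (not part of the commit) of how such a digest could be checked: recompute the hash of a features file written by a dump_features run and compare it with the .md5 sidecar from a hash-only run; verify_feature_dump is a hypothetical helper, and a match assumes torch.save produced byte-identical output in both runs.

import hashlib
from pathlib import Path

def verify_feature_dump(feature_file, digest_file, algo="md5"):
    # digest of the serialized features as written to disk
    h = hashlib.new(algo)
    h.update(Path(feature_file).read_bytes())
    # FileLikeHasher.close() writes the hex digest followed by a newline
    return h.hexdigest() == Path(digest_file).read_text().strip()

# Hypothetical paths for illustration:
# verify_feature_dump("features_pre_nn_part2/.../seq0.bin",
#                     "features_pre_nn_part2/.../seq0.md5")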

compressai_vision/utils/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,7 @@
 
 from . import dataio, git, pip, system
 from .external_exec import get_max_num_cpus
+from .hash import FileLikeHasher, freeze_zip_timestamps
 from .misc import dict_sum, dl_to_ld, ld_to_dl, metric_tracking, time_measure, to_cpu
 
 __all__ = [
@@ -43,4 +44,6 @@
     "dict_sum",
     "dl_to_ld",
     "ld_to_dl",
+    "FileLikeHasher",
+    "freeze_zip_timestamps",
 ]

compressai_vision/utils/hash.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+# Copyright (c) 2022-2024, InterDigital Communications, Inc
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted (subject to the limitations in the disclaimer
+# below) provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice,
+#   this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+# * Neither the name of InterDigital Communications, Inc nor the names of its
+#   contributors may be used to endorse or promote products derived from this
+#   software without specific prior written permission.
+
+# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
+# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
+# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+import hashlib
+import zipfile
+
+from contextlib import contextmanager
+from typing import Tuple
+
+
+class FileLikeHasher:
+    def __init__(self, fn, algo: str = "md5"):
+        self._h = hashlib.new(algo)
+        self._fn = fn
+        self._nbytes = 0
+
+    def write(self, byts):
+        self._h.update(byts)
+        self._nbytes += len(byts)
+        return len(byts)
+
+    def flush(self):
+        pass
+
+    def close(self):
+        with open(self._fn, "w") as f:
+            f.write(self._h.hexdigest())
+            f.write("\n")
+
+
+@contextmanager
+def freeze_zip_timestamps(
+    fixed: Tuple[int, int, int, int, int, int] = (1980, 1, 1, 0, 0, 0),
+):
+    _orig_init = zipfile.ZipInfo.__init__
+
+    def _patched(self, *args, **kwargs):
+        _orig_init(self, *args, **kwargs)
+        self.date_time = fixed  # ZIP fixed time
+
+    zipfile.ZipInfo.__init__ = _patched
+    try:
+        yield
+    finally:
+        zipfile.ZipInfo.__init__ = _orig_init
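
A minimal usage sketch of the two helpers, mirroring how base.py drives them in this commit: torch.save accepts any file-like object implementing write and flush, so the serialized bytes flow through the FileLikeHasher, and close() then writes the hex digest to the given path. freeze_zip_timestamps pins the ZipInfo date_time for anything written with Python's zipfile module inside the context, keeping such archives byte-identical across runs. The feature dict and the seq0.md5 output path are placeholders.

import torch
from compressai_vision.utils import FileLikeHasher, freeze_zip_timestamps

features = {"p2": torch.zeros(1, 4, 4)}  # stand-in for the real feature tensors

hasher = FileLikeHasher("seq0.md5", "md5")  # placeholder output path
with freeze_zip_timestamps():
    # torch.save streams serialized bytes into the hasher instead of a file
    torch.save(features, hasher, pickle_protocol=4)
hasher.close()  # writes the hex digest (plus a trailing newline) to seq0.md5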
