Skip to content

Commit 2242f8a

Browse files
jianwensong authored and fracape committed
[fix] preprocessing features prior to nn part2 before dumping
1 parent 8550a64 commit 2242f8a

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

compressai_vision/pipelines/base.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import logging
3333
import os
3434

35-
from collections import OrderedDict
3635
from enum import Enum
3736
from pathlib import Path
3837
from typing import Callable, Dict, Tuple
@@ -50,7 +49,11 @@
5049
min_max_normalization,
5150
)
5251
from compressai_vision.model_wrappers import BaseWrapper
53-
from compressai_vision.utils import FileLikeHasher, freeze_zip_timestamps
52+
from compressai_vision.utils import (
53+
FileLikeHasher,
54+
contiguous_features,
55+
freeze_zip_timestamps,
56+
)
5457

5558

5659
class Parts(Enum):
@@ -404,11 +407,7 @@ def _from_features_to_output(
404407
self.logger.debug(f"dumping features prior to nn part2 in: {feature_dir}")
405408

406409
# [TODO] align with nn_task_part1 dump features
407-
features_to_dump = {
408-
k: v
409-
for k, v in sorted(x.items(), key=lambda kv: str(kv[0]))
410-
if not k.startswith("file")
411-
}
410+
features_to_dump = contiguous_features(x)
412411

413412
with freeze_zip_timestamps():
414413
if dump_feature_hash:

compressai_vision/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from . import dataio, git, pip, system
3131
from .external_exec import get_max_num_cpus
32-
from .hash import FileLikeHasher, freeze_zip_timestamps
32+
from .hash import FileLikeHasher, contiguous_features, freeze_zip_timestamps
3333
from .misc import dict_sum, dl_to_ld, ld_to_dl, metric_tracking, time_measure, to_cpu
3434

3535
__all__ = [
@@ -46,4 +46,5 @@
4646
"ld_to_dl",
4747
"FileLikeHasher",
4848
"freeze_zip_timestamps",
49+
"contiguous_features",
4950
]

compressai_vision/utils/hash.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,13 @@
3131
import hashlib
3232
import zipfile
3333

34+
from collections import OrderedDict
35+
from collections.abc import Mapping, Sequence
3436
from contextlib import contextmanager
3537
from typing import Tuple
3638

39+
import torch
40+
3741

3842
class FileLikeHasher:
3943
def __init__(self, fn, algo: str = "md5"):
@@ -70,3 +74,23 @@ def _patched(self, *args, **kwargs):
7074
yield
7175
finally:
7276
zipfile.ZipInfo.__init__ = _orig_init
77+
78+
79+
def contiguous_features(obj):
    """Recursively normalise a (nested) feature container.

    The value is walked recursively and rebuilt as follows:

    * ``torch.Tensor`` -> moved to CPU, made contiguous and cloned.
    * ``Mapping`` -> ``OrderedDict`` whose entries are ordered by
      ``str(key)``; entries whose ``str(key)`` starts with ``"file"`` are
      dropped and the remaining values are normalised recursively.
    * ``set`` -> ``tuple`` of its elements ordered by ``str(element)``
      (elements themselves are left untouched).
    * any other ``Sequence`` except ``str``/``bytes`` -> a sequence of the
      same type with every element normalised recursively.
    * anything else is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) object to normalise.

    Returns:
        The normalised counterpart of ``obj`` as described above.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to("cpu").contiguous().clone()

    if isinstance(obj, Mapping):
        return OrderedDict(
            (k, contiguous_features(v))
            for k, v in sorted(obj.items(), key=lambda item: str(item[0]))
            if not str(k).startswith("file")
        )

    if isinstance(obj, set):
        return tuple(sorted(obj, key=str))

    if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
        # Materialise first so that an error raised while recursing is never
        # mistaken for a constructor failure in the ``try`` below.
        items = [contiguous_features(v) for v in obj]

        # namedtuple constructors take one positional argument per field, so
        # handing them a single iterable raises TypeError.
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            return type(obj)(*items)

        try:
            return type(obj)(items)
        except TypeError:
            # Some sequence types (e.g. ``range``) cannot be rebuilt from an
            # iterable; fall back to the closest plain built-in container.
            return tuple(items) if isinstance(obj, tuple) else items

    return obj

0 commit comments

Comments
 (0)