Skip to content

Commit cf0dbc1

Browse files
committed
[fix] bug - 16 bits quantized feature dump and load
1 parent efb7fec commit cf0dbc1

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

compressai_vision/pipelines/base.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ def _prep_features_to_dump(features, n_bits, datacatalog_name):
207207
if n_bits == -1:
208208
data_features = features["data"]
209209
elif n_bits >= 8:
210-
assert n_bits == 8, "currently it only supports dumping features in 8 bits"
210+
assert (
211+
n_bits == 8 or n_bits == 16
212+
), "currently it only supports dumping features in 8 bits or 16 bits"
211213
assert datacatalog_name in list(
212214
MIN_MAX_DATASET.keys()
213215
), f"{datacatalog_name} does not exist in the pre-computed minimum and maximum tables"
@@ -218,7 +220,21 @@ def _prep_features_to_dump(features, n_bits, datacatalog_name):
218220
data.min() >= minv and data.max() <= maxv
219221
), f"{data.min()} should be greater than {minv} and {data.max()} should be less than {maxv}"
220222
out, _ = min_max_normalization(data, minv, maxv, bitdepth=n_bits)
221-
data_features[key] = out.to(torch.uint8)
223+
224+
if n_bits <= 8:
225+
data_features[key] = out.to(torch.uint8)
226+
elif n_bits <= 16:
227+
data_features[key] = {
228+
"lsb": torch.bitwise_and(
229+
out.to(torch.int32), torch.tensor(0xFF)
230+
).to(torch.uint8),
231+
"msb": torch.bitwise_and(
232+
torch.bitwise_right_shift(out.to(torch.int32), 8),
233+
torch.tensor(0xFF),
234+
).to(torch.uint8),
235+
}
236+
else:
237+
raise NotImplementedError
222238
else:
223239
raise NotImplementedError
224240

@@ -230,15 +246,30 @@ def _post_process_loaded_features(features, n_bits, datacatalog_name):
230246
if n_bits == -1:
231247
assert "data" in features
232248
elif n_bits >= 8:
233-
assert n_bits == 8, "currently it only supports dumping features in 8 bits"
249+
assert (
250+
n_bits == 8 or n_bits == 16
251+
), "currently it only supports dumping features in 8 bits or 16 bits"
234252
assert datacatalog_name in list(
235253
MIN_MAX_DATASET.keys()
236254
), f"{datacatalog_name} does not exist in the pre-computed minimum and maximum tables"
237255
minv, maxv = MIN_MAX_DATASET[datacatalog_name]
238256
data_features = {}
239257
for key, data in features["data"].items():
240-
out = min_max_inv_normalization(data, minv, maxv, bitdepth=n_bits)
241-
data_features[key] = out.to(torch.float32)
258+
259+
if n_bits <= 8:
260+
out = min_max_inv_normalization(data, minv, maxv, bitdepth=n_bits)
261+
data_features[key] = out.to(torch.float32)
262+
elif n_bits <= 16:
263+
lsb_part = data["lsb"].to(torch.int32)
264+
msb_part = torch.bitwise_left_shift(data["msb"].to(torch.int32), 8)
265+
recovery = (msb_part + lsb_part).to(torch.float32)
266+
267+
out = min_max_inv_normalization(
268+
recovery, minv, maxv, bitdepth=n_bits
269+
)
270+
data_features[key] = out.to(torch.float32)
271+
else:
272+
raise NotImplementedError
242273

243274
features["data"] = data_features
244275
else:

0 commit comments

Comments
 (0)