@@ -207,7 +207,9 @@ def _prep_features_to_dump(features, n_bits, datacatalog_name):
     if n_bits == -1:
         data_features = features["data"]
     elif n_bits >= 8:
-        assert n_bits == 8, "currently it only supports dumping features in 8 bits"
+        assert (
+            n_bits == 8 or n_bits == 16
+        ), "currently it only supports dumping features in 8 bits or 16 bits"
         assert datacatalog_name in list(
             MIN_MAX_DATASET.keys()
         ), f"{datacatalog_name} does not exist in the pre-computed minimum and maximum tables"
@@ -218,7 +220,21 @@ def _prep_features_to_dump(features, n_bits, datacatalog_name):
                 data.min() >= minv and data.max() <= maxv
             ), f"{data.min()} should be greater than {minv} and {data.max()} should be less than {maxv}"
             out, _ = min_max_normalization(data, minv, maxv, bitdepth=n_bits)
+
-            data_features[key] = out.to(torch.uint8)
+            if n_bits <= 8:
+                data_features[key] = out.to(torch.uint8)
+            elif n_bits <= 16:
+                data_features[key] = {
+                    "lsb": torch.bitwise_and(
+                        out.to(torch.int32), torch.tensor(0xFF)
+                    ).to(torch.uint8),
+                    "msb": torch.bitwise_and(
+                        torch.bitwise_right_shift(out.to(torch.int32), 8),
+                        torch.tensor(0xFF),
+                    ).to(torch.uint8),
+                }
+            else:
+                raise NotImplementedError
     else:
         raise NotImplementedError

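For readers skimming the dump path: the 16-bit branch packs each normalized value into two uint8 planes, the low byte stored under "lsb" and the high byte under "msb". A minimal standalone sketch of that packing in plain torch, with a hypothetical `quantized` tensor standing in for `out` (values assumed already quantized to the 0..65535 range at bitdepth 16):

import torch

# hypothetical stand-in for the quantized feature tensor `out`
quantized = torch.tensor([0, 255, 256, 65535], dtype=torch.int32)
# low byte: keep the 8 least significant bits
lsb = torch.bitwise_and(quantized, torch.tensor(0xFF)).to(torch.uint8)
# high byte: shift the top 8 bits down, then mask
msb = torch.bitwise_and(
    torch.bitwise_right_shift(quantized, 8), torch.tensor(0xFF)
).to(torch.uint8)
# lsb -> [0, 255, 0, 255], msb -> [0, 0, 1, 255]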
@@ -230,15 +246,30 @@ def _post_process_loaded_features(features, n_bits, datacatalog_name):
     if n_bits == -1:
         assert "data" in features
     elif n_bits >= 8:
-        assert n_bits == 8, "currently it only supports dumping features in 8 bits"
+        assert (
+            n_bits == 8 or n_bits == 16
+        ), "currently it only supports dumping features in 8 bits or 16 bits"
         assert datacatalog_name in list(
             MIN_MAX_DATASET.keys()
         ), f"{datacatalog_name} does not exist in the pre-computed minimum and maximum tables"
         minv, maxv = MIN_MAX_DATASET[datacatalog_name]
         data_features = {}
         for key, data in features["data"].items():
-            out = min_max_inv_normalization(data, minv, maxv, bitdepth=n_bits)
-            data_features[key] = out.to(torch.float32)
+
+            if n_bits <= 8:
+                out = min_max_inv_normalization(data, minv, maxv, bitdepth=n_bits)
+                data_features[key] = out.to(torch.float32)
+            elif n_bits <= 16:
+                lsb_part = data["lsb"].to(torch.int32)
+                msb_part = torch.bitwise_left_shift(data["msb"].to(torch.int32), 8)
+                recovery = (msb_part + lsb_part).to(torch.float32)
+
+                out = min_max_inv_normalization(
+                    recovery, minv, maxv, bitdepth=n_bits
+                )
+                data_features[key] = out.to(torch.float32)
+            else:
+                raise NotImplementedError

         features["data"] = data_features
     else:
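And the corresponding sketch of the load path: the 16-bit branch rebuilds the integer values from the two uint8 planes before handing them to min_max_inv_normalization. The `lsb`/`msb` tensors below are hypothetical and mirror the packing sketched after the first hunk:

import torch

lsb = torch.tensor([0, 255, 0, 255], dtype=torch.uint8)
msb = torch.tensor([0, 0, 1, 255], dtype=torch.uint8)
# shift the high byte back up by 8 bits and add the low byte
recovered = torch.bitwise_left_shift(msb.to(torch.int32), 8) + lsb.to(torch.int32)
# recovered -> [0, 255, 256, 65535]; the inverse normalization (not shown here)
# then maps these integers back to float32 feature values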