Skip to content

Commit 17f17b3

Browse files
alex-hhlhoestq
andauthored
support for custom feature encoding/decoding (#7284)
* support for custom feature encoding/decoding * Update src/datasets/features/features.py --------- Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 2049c00 commit 17f17b3

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/datasets/features/features.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,7 +1348,7 @@ def encode_nested_example(schema, obj, level=0):
13481348
return list(obj)
13491349
# Object with special encoding:
13501350
# ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
1351-
elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)):
1351+
elif hasattr(schema, "encode_example"):
13521352
return schema.encode_example(obj) if obj is not None else None
13531353
# Other object should be directly convertible to a native Arrow type (like Translation and Translation)
13541354
return obj
@@ -1399,10 +1399,9 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
13991399
else:
14001400
return decode_nested_example([schema.feature], obj)
14011401
# Object with special decoding:
1402-
elif isinstance(schema, (Audio, Image, Video)):
1402+
elif hasattr(schema, "decode_example") and getattr(schema, "decode", True):
14031403
# we pass the token to read and decode files from private repositories in streaming mode
1404-
if obj is not None and schema.decode:
1405-
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
1404+
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id) if obj is not None else None
14061405
return obj
14071406

14081407

@@ -1629,7 +1628,9 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False
16291628
elif isinstance(feature, Sequence):
16301629
return require_decoding(feature.feature)
16311630
else:
1632-
return hasattr(feature, "decode_example") and (feature.decode if not ignore_decode_attribute else True)
1631+
return hasattr(feature, "decode_example") and (
1632+
getattr(feature, "decode", True) if not ignore_decode_attribute else True
1633+
)
16331634

16341635

16351636
def require_storage_cast(feature: FeatureType) -> bool:

0 commit comments

Comments
 (0)