File tree Expand file tree Collapse file tree 1 file changed +43
-0
lines changed
Expand file tree Collapse file tree 1 file changed +43
-0
lines changed Original file line number Diff line number Diff line change 4343else :
4444 HAVE_NUMPY = True
4545
46+ try :
47+ import pandas
48+ except ImportError :
49+ HAVE_PANDAS = False
50+ else :
51+ HAVE_PANDAS = True
52+
53+ try :
54+ import torch
55+ except ImportError :
56+ HAVE_PYTORCH = False
57+ else :
58+ HAVE_PYTORCH = True
59+
60+ try :
61+ import tensorflow
62+ except ImportError :
63+ HAVE_TENSORFLOW = False
64+ else :
65+ HAVE_TENSORFLOW = True
66+
4667__all__ = (
4768 "hash_function" ,
4869 "hash_object" ,
@@ -564,5 +585,27 @@ def bytes_repr_numpy(obj: numpy.ndarray, cache: Cache) -> Iterator[bytes]:
564585 else :
565586 yield obj .tobytes (order = "C" )
566587
588+ if HAVE_PYTORCH :
589+
590+ @register_serializer (torch .Tensor )
591+ def bytes_repr_torch (obj : torch .Tensor , cache : Cache ) -> Iterator [bytes ]:
592+ yield f"{ obj .__class__ .__module__ } { obj .__class__ .__name__ } :" .encode ()
593+ yield from bytes_repr_numpy (obj .numpy (), cache )
594+
595+
596+ if HAVE_TENSORFLOW :
597+
598+ @register_serializer (tensorflow .Tensor )
599+ def bytes_repr_tensorflow (obj : tensorflow .Tensor , cache : Cache ) -> Iterator [bytes ]:
600+ yield f"{ obj .__class__ .__module__ } { obj .__class__ .__name__ } :" .encode ()
601+ yield from bytes_repr_numpy (obj .numpy (), cache )
602+
603+
604+ if HAVE_PANDAS :
605+
606+ @register_serializer (pandas .DataFrame )
607+ def bytes_repr_pandas (obj : pandas .DataFrame , cache : Cache ) -> Iterator [bytes ]:
608+ yield f"{ obj .__class__ .__module__ } { obj .__class__ .__name__ } :" .encode ()
609+ yield from bytes_repr_numpy (obj .to_numpy (), cache )
567610
568611NUMPY_CHUNK_LEN = 8192
You can’t perform that action at this time.
0 commit comments