Skip to content

Commit 233470d

Browse files
committed
Registered serializer for common classes of additional array-like objects
1 parent d674840 commit 233470d

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

pydra/utils/hash.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,27 @@
4343
else:
4444
HAVE_NUMPY = True
4545

46+
try:
47+
import pandas
48+
except ImportError:
49+
HAVE_PANDAS = False
50+
else:
51+
HAVE_PANDAS = True
52+
53+
try:
54+
import torch
55+
except ImportError:
56+
HAVE_PYTORCH = False
57+
else:
58+
HAVE_PYTORCH = True
59+
60+
try:
61+
import tensorflow
62+
except ImportError:
63+
HAVE_TENSORFLOW = False
64+
else:
65+
HAVE_TENSORFLOW = True
66+
4667
__all__ = (
4768
"hash_function",
4869
"hash_object",
@@ -564,5 +585,27 @@ def bytes_repr_numpy(obj: numpy.ndarray, cache: Cache) -> Iterator[bytes]:
564585
else:
565586
yield obj.tobytes(order="C")
566587

588+
if HAVE_PYTORCH:
589+
590+
@register_serializer(torch.Tensor)
591+
def bytes_repr_torch(obj: torch.Tensor, cache: Cache) -> Iterator[bytes]:
592+
yield f"{obj.__class__.__module__}{obj.__class__.__name__}:".encode()
593+
yield from bytes_repr_numpy(obj.numpy(), cache)
594+
595+
596+
if HAVE_TENSORFLOW:
597+
598+
@register_serializer(tensorflow.Tensor)
599+
def bytes_repr_tensorflow(obj: tensorflow.Tensor, cache: Cache) -> Iterator[bytes]:
600+
yield f"{obj.__class__.__module__}{obj.__class__.__name__}:".encode()
601+
yield from bytes_repr_numpy(obj.numpy(), cache)
602+
603+
604+
if HAVE_PANDAS:
605+
606+
@register_serializer(pandas.DataFrame)
607+
def bytes_repr_pandas(obj: pandas.DataFrame, cache: Cache) -> Iterator[bytes]:
608+
yield f"{obj.__class__.__module__}{obj.__class__.__name__}:".encode()
609+
yield from bytes_repr_numpy(obj.to_numpy(), cache)
567610

568611
NUMPY_CHUNK_LEN = 8192

0 commit comments

Comments
 (0)