Skip to content

Commit 67f41d3

Browse files
committed
add: allow hashing for int tensors
1 parent 4bcee01 commit 67f41d3

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

mldaikon/proxy_wrapper/hash.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,25 @@ def tensor_hash(x: Tensor, with_parallel: bool = True, with_cuda: bool = True) -
8888
if hasattr(x, "_mldaikon_tensor_hash"):
8989
return x._mldaikon_tensor_hash
9090
if with_parallel:
91-
assert x.dtype in [
91+
if x.dtype in [
9292
torch.float32,
9393
torch.float64,
9494
torch.bfloat16,
9595
torch.float16,
9696
torch.float,
97-
]
98-
99-
# Convert the floating-point tensor to an integer representation
100-
x = (x * 1e8).to(torch.int64)
97+
]:
98+
# Convert the floating-point tensor to an integer representation
99+
x = (x * 1e8).to(torch.int64)
100+
else:
101+
assert x.dtype in [
102+
torch.int32,
103+
torch.int64,
104+
torch.uint8,
105+
torch.int8,
106+
torch.int16,
107+
], f"Unsupported tensor type for hashing, expected either int or float, got {x.dtype}"
108+
# Ensure the tensor is of integer type
109+
x = x.to(torch.int64)
101110

102111
# Ensure the tensor is of integer type
103112
assert x.dtype == torch.int64

0 commit comments

Comments
 (0)