Skip to content

Commit 0431f25

Browse files
authored
fix: change datatype of simhash to string, because pyarrow is incompatible with uint64 (#170)
1 parent afac978 commit 0431f25

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

data_juicer/ops/deduplicator/document_simhash_deduplicator.py

Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
# https://github.com/bigscience-workshop/data-preparation
33
# --------------------------------------------------------
44

5-
from collections import Counter, defaultdict, deque
5+
from collections import defaultdict, deque
66
from typing import Dict, Set
77

88
import numpy as np
@@ -156,8 +156,8 @@ def compute_hash(self, sample):
156156
f'Unimplemented tokenization method [{self.tokenization}]')
157157

158158
# compute simhash
159-
sample[HashKeys.simhash] = np.uint64(
160-
simhash.compute(map(simhash.unsigned_hash, tokens)))
159+
sample[HashKeys.simhash] = str(
160+
np.uint64(simhash.compute(map(simhash.unsigned_hash, tokens))))
161161
return sample
162162

163163
def process(self, dataset, show_num=0):
@@ -176,25 +176,23 @@ def process(self, dataset, show_num=0):
176176
# find matches
177177
logger.info(f'Start querying {len(dataset)} samples.')
178178
matches = simhash.find_all(
179-
dataset[HashKeys.simhash],
179+
np.uint64(dataset[HashKeys.simhash]),
180180
self.num_blocks,
181181
self.hamming_distance,
182182
)
183183
logger.info(f'Querying done, found {len(matches)} matches.')
184184

185185
# compute hash diff distribution
186186
graph = defaultdict(dict)
187-
dist = Counter()
188187
for x, y in matches:
188+
x = str(x)
189+
y = str(y)
189190
graph[x][y] = graph[y][x] = True
190-
num_diff = num_differing_bits(x, y)
191-
dist[num_diff] += 1
192-
logger.info(f'Hash diff distribution: {dist}')
193-
194-
hash2ids: Dict[int, Set[str]] = defaultdict(set)
195-
hashes: Set[int] = set(dataset[HashKeys.simhash])
196-
hash2cluster: Dict[int, int] = {}
197-
visited: Set[int] = set()
191+
192+
hash2ids: Dict[str, Set[str]] = defaultdict(set)
193+
hashes: Set[str] = set(dataset[HashKeys.simhash])
194+
hash2cluster: Dict[str, int] = {}
195+
visited: Set[str] = set()
198196
cluster_id: int = 0
199197

200198
for sid, hash_val in enumerate(dataset[HashKeys.simhash]):

0 commit comments

Comments (0)