Skip to content

Commit a0f78a3

Browse files
committed
Update for API test & API consistency
1 parent 4c42367 commit a0f78a3

File tree

10 files changed

+32
-22
lines changed

10 files changed

+32
-22
lines changed

configs/process/dedup_api.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
HashDeduplicator:
2+
hash_func: 'md5'
3+
CCNetDeduplicator:
4+
bit_length: 64 # should be a multiple of 8

configs/process/filter_api.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,4 @@ FineWebEduFilter:
77
min_score: 0
88
max_score: 100
99
scorer_args:
10-
model_name: 'HuggingFaceTB/fineweb-edu-classifier'
11-
device: 'cuda:4'
10+
model_name: 'HuggingFaceTB/fineweb-edu-classifier'

dataflow/core/process/deduplicator.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ def __init__(self, args=None):
2323
def __call__(self, dataset):
2424
init_len = len(dataset)
2525
labels = self.dedup_func(dataset)
26-
if isinstance(dataset.dataset, Dataset):
27-
def filter_by_labels(example, index):
28-
return labels[index] == 1
29-
dataset.dataset = dataset.dataset.filter(filter_by_labels, with_indices=True)
30-
deduped_dataset = dataset
31-
else:
32-
deduped_dataset = dataset.filter(labels)
33-
print(f'Implemented {self.dedupliactor_name}. Data Number: {init_len} -> {len(deduped_dataset)}')
34-
return deduped_dataset
26+
# if isinstance(dataset.dataset, Dataset):
27+
# def filter_by_labels(example, index):
28+
# return labels[index] == 1
29+
# dataset.dataset = dataset.dataset.filter(filter_by_labels, with_indices=True)
30+
# deduped_dataset = dataset
31+
# else:
32+
# deduped_dataset = dataset.filter(labels)
33+
# print(f'Implemented {self.dedupliactor_name}. Data Number: {init_len} -> {len(deduped_dataset)}')
34+
return labels
3535

3636
class ImageDeduplicator(Deduplicator):
3737

dataflow/process/text/deduplicators/ccnet_deduplicator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ def dedup_func(self, dataset):
2727
hash_values.append(hash_value)
2828
# print(json.dumps({"ccnet_hash_values": hash_values}))
2929

30-
return json.dumps({"ccnet_hash_values": hash_values})
30+
return {"ccnet_hash_values": hash_values}
3131

3232

dataflow/process/text/deduplicators/hash_deduplicator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ def dedup_func(self, dataset):
3131
hash_value = self._compute_hash(text)
3232
hash_values.append(hash_value)
3333
# print(json.dumps({"hash_values": hash_values}))
34-
return json.dumps({"exact_hash_values": hash_values})
34+
return {"exact_hash_values": hash_values}

dataflow/process/text/deduplicators/minhash_deduplicator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def dedup_func(self, dataset):
3636
result = lsh.query(minhash)
3737
hash_values.append(result)
3838
# print(json.dumps({"hash_values": hash_values}))
39-
return json.dumps({"minhash_hash_values": hash_values})
39+
return {"minhash_hash_values": hash_values}
4040

4141

4242

dataflow/process/text/deduplicators/ngramhash_deduplicator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def dedup_func(self, dataset):
3333
hash_value = set(self._compute_hash(ngram) for ngram in ngrams)
3434
hash_values.append(hash_value)
3535
# print(json.dumps({"hash_values": hash_values}))
36-
return json.dumps({"ngram_hash_values": hash_values})
36+
return {"ngram_hash_values": hash_values}
3737

3838

3939

dataflow/process/text/deduplicators/sem_deduplicator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,4 @@ def dedup_func(self, dataset):
8888
embeddings = get_text_embedding(texts, self.tokenizer, self.model, self.device)
8989
embeddings = normalize(torch.tensor(embeddings), dim=1)
9090
# print(json.dumps({"embeddings": embeddings.tolist()}))
91-
return json.dumps({"semhash_embeddings": embeddings.tolist()})
91+
return {"semhash_embeddings": embeddings.tolist()}

dataflow/process/text/deduplicators/simhash_deduplicator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,6 @@ def get_similarity(simhash, another_simhash):
3232
simhash = Simhash(text, f=self.fingerprint_size)
3333
simhashes.append(simhash)
3434
# print(json.dumps({"hash_values": [simhash.value for simhash in simhashes]}))
35-
return json.dumps({"simhash_values": [simhash.value for simhash in simhashes]})
35+
return {"simhash_values": [simhash.value for simhash in simhashes]}
3636

3737

dataflow/utils/utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,11 @@ def filter():
286286
# dataset.dump(save_path)
287287
result[recorder] = True
288288
result = result.tolist()
289+
save_path = cfg['save_path']
290+
from bitarray import bitarray
291+
ba = bitarray(result)
292+
with open(save_path, 'wb') as f:
293+
ba.tofile(f)
289294
print(json.dumps({"bool": result}))
290295

291296
def refine():
@@ -299,7 +304,7 @@ def refine():
299304
if isinstance(cfg.yaml, str):
300305
with open(cfg.yaml, 'r') as f:
301306
cfg.yaml = yaml.safe_load(f) # 解析成字典
302-
307+
303308
for scorer_name, args in cfg.yaml.items():
304309
if "num_workers" in cfg:
305310
args["num_workers"] = cfg.num_workers
@@ -329,7 +334,7 @@ def deduplicate():
329334
if isinstance(cfg.yaml, str):
330335
with open(cfg.yaml, 'r') as f:
331336
cfg.yaml = yaml.safe_load(f) # 解析成字典
332-
337+
result = []
333338
for scorer_name, args in cfg.yaml.items():
334339
if "num_workers" in cfg:
335340
args["num_workers"] = cfg.num_workers
@@ -342,9 +347,11 @@ def deduplicate():
342347
dataset_dict[processor.data_type] = datasets
343348
else:
344349
datasets = dataset_dict[processor.data_type]
345-
processed_dataset = processor(datasets)
346-
dataset_dict[processor.data_type] = processed_dataset
347-
print(processed_dataset)
350+
result.append(processor(datasets))
351+
# dataset_dict[processor.data_type] = processed_dataset
352+
save_path = cfg['save_path']
353+
with open(save_path, 'w') as f:
354+
json.dump(result, f)
348355

349356

350357

0 commit comments

Comments (0)