Skip to content

Commit 2a167b9

Browse files
committed
Merge branch 'API' of github.com:Open-DataFlow/DataFlow-Eval-Process into API
2 parents 73dd364 + 014081a commit 2a167b9

File tree

8 files changed

+48
-12
lines changed

8 files changed

+48
-12
lines changed

dataflow/process/text/deduplicators/ccnet_deduplicator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ def dedup_func(self, dataset):
2525
text = str(sample[dataset.keys]).encode('utf-8')
2626
hash_value = self._compute_hash(text)
2727
hash_values.append(hash_value)
28-
print(json.dumps({"hash_values": hash_values}))
29-
return hash_values
28+
# print(json.dumps({"ccnet_hash_values": hash_values}))
29+
30+
return json.dumps({"ccnet_hash_values": hash_values})
3031

3132

dataflow/process/text/deduplicators/hash_deduplicator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@ def dedup_func(self, dataset):
3030

3131
hash_value = self._compute_hash(text)
3232
hash_values.append(hash_value)
33-
print(json.dumps({"hash_values": hash_values}))
34-
return json.dumps({"hash_values": hash_values})
33+
# print(json.dumps({"hash_values": hash_values}))
34+
return json.dumps({"exact_hash_values": hash_values})

dataflow/process/text/deduplicators/minhash_deduplicator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def dedup_func(self, dataset):
3535
minhash = self.create_minhash(text)
3636
result = lsh.query(minhash)
3737
hash_values.append(result)
38-
print(json.dumps({"hash_values": hash_values}))
39-
return json.dumps({"hash_values": hash_values})
38+
# print(json.dumps({"hash_values": hash_values}))
39+
return json.dumps({"minhash_hash_values": hash_values})
4040

4141

4242

dataflow/process/text/deduplicators/ngramhash_deduplicator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def dedup_func(self, dataset):
3232
ngrams = [text[i*gram_length:(i+1)*gram_length] for i in range(self.n_gram)]
3333
hash_value = set(self._compute_hash(ngram) for ngram in ngrams)
3434
hash_values.append(hash_value)
35-
print(json.dumps({"hash_values": hash_values}))
36-
return json.dumps({"hash_values": hash_values})
35+
# print(json.dumps({"hash_values": hash_values}))
36+
return json.dumps({"ngram_hash_values": hash_values})
3737

3838

3939

dataflow/process/text/deduplicators/sem_deduplicator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,5 @@ def dedup_func(self, dataset):
8787
# Compute embeddings for the dataset texts
8888
embeddings = get_text_embedding(texts, self.tokenizer, self.model, self.device)
8989
embeddings = normalize(torch.tensor(embeddings), dim=1)
90-
print(json.dumps({"embeddings": embeddings.tolist()}))
91-
return json.dumps({"embeddings": embeddings.tolist()})
90+
# print(json.dumps({"embeddings": embeddings.tolist()}))
91+
return json.dumps({"semhash_embeddings": embeddings.tolist()})

dataflow/process/text/deduplicators/simhash_deduplicator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_similarity(simhash, another_simhash):
3131
text = str(sample[dataset.keys])
3232
simhash = Simhash(text, f=self.fingerprint_size)
3333
simhashes.append(simhash)
34-
print(json.dumps({"hash_values": [simhash.value for simhash in simhashes]}))
35-
return json.dumps({"hash_values": [simhash.value for simhash in simhashes]})
34+
# print(json.dumps({"hash_values": [simhash.value for simhash in simhashes]}))
35+
return json.dumps({"simhash_values": [simhash.value for simhash in simhashes]})
3636

3737

dataflow/utils/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,35 @@ def refine():
317317
save_path = cfg['save_path']
318318
for dataset in dataset_dict.values():
319319
dataset.dump(save_path)
320+
321+
def deduplicate():
    """Run the configured deduplication processors over the input datasets.

    Reads the API config, resolves each configured processor by name,
    lazily loads/formats the dataset for that processor's data type
    (reusing it across processors of the same type), applies the
    processor, and prints each processed result to stdout.

    Side effects:
        Prints the processed dataset for every configured processor.
    """
    from ..config import api_init_config
    from dataflow.data import DataFlowDSDict
    from dataflow.utils.registry import FORMATTER_REGISTRY
    # NOTE(review): ScoreRecord appears unused below — possibly imported for
    # registration side effects; confirm before removing.
    from dataflow.core import ScoreRecord
    cfg = api_init_config()
    dataset_dict = DataFlowDSDict()

    # cfg.yaml may arrive as a path string; load it into a dict.
    if isinstance(cfg.yaml, str):
        with open(cfg.yaml, 'r') as f:
            cfg.yaml = yaml.safe_load(f)  # parse into a dict

    for scorer_name, args in cfg.yaml.items():
        # Propagate top-level config overrides into each processor's args.
        if "num_workers" in cfg:
            args["num_workers"] = cfg.num_workers
        if "model_cache_path" in cfg:
            args["model_cache_dir"] = cfg.model_cache_path
        processor = get_processor(scorer_name, args)
        # Load and format the dataset once per data type, then reuse it.
        if processor.data_type not in dataset_dict.keys():
            formatter = FORMATTER_REGISTRY.get('TextFormatter')(cfg['data'], cfg['key'], cfg['sft_single_round'], cfg['sft_multi_round'], cfg['RLHF'])
            datasets = formatter.load_dataset()
            dataset_dict[processor.data_type] = datasets
        else:
            datasets = dataset_dict[processor.data_type]
        processed_dataset = processor(datasets)
        # Cache the processed result so later processors of the same data
        # type operate on the deduplicated dataset.
        dataset_dict[processor.data_type] = processed_dataset
        print(processed_dataset)

deduplicator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""Entry-point script: run dataset deduplication via the dataflow API."""
from dataflow.utils.utils import deduplicate

# Guard the entry point so importing this module has no side effects;
# behavior when executed as a script is unchanged.
if __name__ == "__main__":
    deduplicate()

0 commit comments

Comments
 (0)