
Commit ac278a0

committed: remove unused code.
1 parent 8d7d886 commit ac278a0

5 files changed (+33, −214 lines)


fibber/benchmark/benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -127,8 +127,8 @@ def __init__(self,
             ce_gpu_id=ce_gpu_id,
             bert_clf_enable_sem=bert_clf_enable_sem,
             bert_clf_enable_lmag=bert_clf_enable_lmag,
-            enable_ce_similarity=False,
-            enable_gpt2_perplexity=False
+            enable_ce_similarity=True,
+            enable_gpt2_perplexity=True
         )
 
         if customized_clf:
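
The two flipped flags simply turn the cross-encoder similarity and GPT-2 perplexity metrics back on when the benchmark assembles its metric suite. Below is a minimal, self-contained sketch of that boolean-gating pattern; the builder function, metric stubs, and scores are illustrative assumptions, not fibber's actual API.

# Illustrative sketch of boolean-gated metric construction; the function name,
# metric stubs, and placeholder scores are assumptions for demonstration only.
def build_metric_suite(enable_ce_similarity=True, enable_gpt2_perplexity=True):
    metrics = {"edit_distance": lambda origin, paraphrase: abs(len(origin) - len(paraphrase))}
    if enable_ce_similarity:
        metrics["ce_similarity"] = lambda origin, paraphrase: 0.95   # placeholder score
    if enable_gpt2_perplexity:
        metrics["gpt2_perplexity"] = lambda origin, paraphrase: 42.0  # placeholder score
    return metrics

print(sorted(build_metric_suite()))              # all three metrics enabled
print(sorted(build_metric_suite(False, False)))  # only the always-on metric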

fibber/metrics/gpt2_perplexity_metric.py

Lines changed: 19 additions & 6 deletions
@@ -81,30 +81,43 @@ def _get_ppl(self, sentences):
         ppl = ppl.detach().cpu().numpy()
         return ppl
 
-    def measure_batch(self, origin, paraphrase_list, data_record=None, paraphrase_field="text0"):
+    def measure_batch(self, origin, paraphrase_list, data_record=None, paraphrase_field="text0",
+                      use_ratio=False):
         """Measure the metric on a batch of paraphrase_list.
 
         Args:
             origin (str): the original text.
             paraphrase_list (list): a set of paraphrase_list.
             data_record (dict): the corresponding data record of original text.
             paraphrase_field (str): the field name to paraphrase.
+            use_ratio (bool): if True, return the perplexity ratio instead of raw perplexity.
 
         Returns:
             (list): a list containing the USE similarity metric for each paraphrase.
         """
-        ppls = self._get_ppl([origin] + paraphrase_list)
-        res = ppls[1:] / ppls[0]
+        if use_ratio:
+            ppls = self._get_ppl([origin] + paraphrase_list)
+            res = ppls[1:] / ppls[0]
+        else:
+            res = self._get_ppl(paraphrase_list)
         return [float(x) for x in res]
 
-    def measure_example(self, origin, paraphrase, data_record=None, paraphrase_field="text0"):
+    def measure_example(self, origin, paraphrase, data_record=None, paraphrase_field="text0",
+                        use_ratio=False):
         """Compute the perplexity ratio.
 
        Args:
            origin (str): original text.
            paraphrase (str): paraphrased text.
            data_record: ignored.
            paraphrase_field: ignored.
+           use_ratio (bool): if True, return the perplexity ratio instead of raw perplexity.
+
        """
-        ppl = self._get_ppl([origin, paraphrase])
-        return float(ppl[1] / ppl[0])
+        if use_ratio:
+            ppl = self._get_ppl([origin, paraphrase])
+            res = float(ppl[1] / ppl[0])
+        else:
+            res = float(self._get_ppl([paraphrase])[0])
+
+        return res
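
After this change, both methods return the raw GPT-2 perplexity of the paraphrase by default and only fall back to the old ratio-against-the-origin behaviour when use_ratio=True is passed. The sketch below mirrors that branching with a stand-in for _get_ppl; the stub scoring and example texts are assumptions for illustration, not the real model.

# Minimal sketch of the new use_ratio semantics. _get_ppl_stub stands in for the
# real GPT-2 perplexity computation; its scores are made up for illustration.
import numpy as np

def _get_ppl_stub(sentences):
    # Pretend longer sentences are more perplexing.
    return np.asarray([10.0 + 2.0 * len(s.split()) for s in sentences])

def measure_example(origin, paraphrase, use_ratio=False):
    if use_ratio:
        ppl = _get_ppl_stub([origin, paraphrase])
        return float(ppl[1] / ppl[0])             # paraphrase ppl / origin ppl
    return float(_get_ppl_stub([paraphrase])[0])  # raw paraphrase perplexity

print(measure_example("a short origin", "a somewhat longer paraphrase"))                  # raw ppl
print(measure_example("a short origin", "a somewhat longer paraphrase", use_ratio=True))  # ratio > 1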

fibber/paraphrase_strategies/asrs_strategy.py

Lines changed: 12 additions & 69 deletions
@@ -8,7 +8,6 @@
 
 from fibber import log
 from fibber.metrics.bert_lm_utils import get_lm
-from fibber.paraphrase_strategies.asrs_utils_text_parser import TextParser
 from fibber.paraphrase_strategies.asrs_utils_wpe import get_wordpiece_emb
 from fibber.paraphrase_strategies.strategy_base import StrategyBase
 
@@ -132,7 +131,7 @@ def ppl_criteria_score(origin, paraphrases, ppl_metric, ppl_weight):
     """
     if ppl_weight == 0:
         return np.zeros(len(paraphrases), dtype="float32")
-    ppl_ratio = ppl_metric.measure_batch(origin, paraphrases)
+    ppl_ratio = ppl_metric.measure_batch(origin, paraphrases, use_ratio=True)
     return -ppl_weight * (np.maximum(np.asarray(ppl_ratio) - 1, 0) ** 2)
 
 
@@ -272,11 +271,9 @@ class ASRSStrategy(StrategyBase):
         ("burnin_criteria_schedule", str, "1", ("the schedule decides how strict the criteria is "
                                                 "used. options are [linear, 0, 1].")),
         ("seed_option", str, "origin", ("the option for seed sentences in generation. "
-                                        "choose from [origin, auto, dynamic_len].")),
+                                        "choose from [origin, dynamic_len].")),
         ("dynamic_len_min", int, -3, "change length min."),
         ("dynamic_len_max", int, 3, "change length max."),
-        ("split_sentence", str, "auto", "split paragraph to sentence. options are [0, 1, auto]."),
-        ("stanza_port", int, 9000, "stanza port"),
         ("lm_option", str, "finetune", "choose from [pretrain, finetune, adv]."),
         ("lm_steps", int, 5000, "lm training steps."),
         ("clf_weight", float, 3, "weight for the clf score in the criteria."),
@@ -304,7 +301,7 @@ def fit(self, trainset):
         self._sim_metric = self._metric_bundle.get_metric(
             self._strategy_config["sim_metric"])
         self._clf_metric = self._metric_bundle.get_target_classifier()
-        self._ppl_metric = self._metric_bundle.get_metric("GPT2PerplexityMetric")
+        self._ppl_metric = self._metric_bundle.get_metric("BertPerplexityMetric")
 
         # load word piece embeddings.
         wpe = get_wordpiece_emb(self._output_dir, self._dataset_name, trainset, self._device)
@@ -332,13 +329,6 @@ def fit(self, trainset):
         else:
             assert 0
 
-        # load text parser
-        if (self._strategy_config["seed_option"] != "origin"
-                or self._strategy_config["split_sentence"] != "0"):
-            self._text_parser = TextParser(self._strategy_config["stanza_port"])
-        else:
-            self._text_parser = None
-
         self._stats = {
             "all": 0,
             "accept": 0
@@ -351,14 +341,6 @@ def _parallel_sequential_generation(self, original_text, seed, batch_size, burni
             batch_tensor = torch.tensor(
                 [self._tokenizer.convert_tokens_to_ids(seq)] * batch_size).to(self._device)
             seq_len = [len(seq)] * batch_size
-        elif self._strategy_config["seed_option"] == "auto":
-            seeds = self._text_parser.phrase_level_shuffle(seed, batch_size)
-            seq = [self._tokenizer.tokenize(x) for x in seeds]
-            seq_len = ([len(x) + 2 for x in seq])
-            max_len = max(seq_len)
-            seq = [["[CLS]"] + x + ["[SEP]"] + ["[PAD]"] * (max_len - len(x) - 2) for x in seq]
-            batch_tensor = torch.tensor(
-                [self._tokenizer.convert_tokens_to_ids(x) for x in seq]).to(self._device)
         elif self._strategy_config["seed_option"] == "dynamic_len":
             seq_raw = self._tokenizer.tokenize(seed)
             seq = []
@@ -536,7 +518,7 @@ def paraphrase_example(self, data_record, field_name, n):
         self._bert_lm = self._bert_lms[data_record["label"]]
         self._bert_lm.to(self._device)
 
-        clipped_text = " ".join(data_record[field_name].split()[:200])
+        clipped_text = data_record[field_name]
         clipped_text = process_text(clipped_text, PRE_PROCESSING_PATTERN)
         batch_size = self._strategy_config["batch_size"]
 
@@ -550,53 +532,14 @@ def paraphrase_example(self, data_record, field_name, n):
             burnin_steps = self._strategy_config["burnin_steps"]
             sampling_steps = self._strategy_config["sampling_steps"]
 
-            if self._strategy_config["split_sentence"] == "0":
-                batch = self._parallel_sequential_generation(
-                    clipped_text,
-                    clipped_text if "seed" not in data_record else data_record["seed"],
-                    batch_size if id != n_batches - 1 else last_batch_size,
-                    burnin_steps,
-                    sampling_steps,
-                    field_name, data_record)
-                sentences += batch
-            elif self._strategy_config["split_sentence"] in ["1", "auto"]:
-                splitted_text_ori = self._text_parser.split_paragraph_to_sentences(
-                    clipped_text)
-
-                if self._strategy_config["split_sentence"] == "auto":
-                    splitted_text = []
-                    current_text = ""
-                    for s in splitted_text_ori:
-                        current_text += " " + s
-                        if len(current_text.split()) > AUTO_SENTENCE_LEN_THRESHOLD:
-                            splitted_text.append(current_text)
-                            current_text = ""
-                    if len(current_text.split()) > 0:
-                        splitted_text.append(current_text)
-
-                    burnin_steps = self._strategy_config["burnin_steps"] // len(splitted_text)
-                    sampling_steps = self._strategy_config["sampling_steps"] // len(splitted_text)
-                elif self._strategy_config["split_sentence"] == "1":
-                    splitted_text = splitted_text_ori
-                    burnin_steps = self._strategy_config["burnin_steps"] // len(splitted_text)
-                    sampling_steps = self._strategy_config["sampling_steps"] // len(splitted_text)
-                else:
-                    assert 0
-
-                batch_res = [""] * (batch_size if id != n_batches - 1 else last_batch_size)
-                for text in splitted_text:
-                    batch = self._parallel_sequential_generation(
-                        text,
-                        text,
-                        batch_size if id != n_batches - 1 else last_batch_size,
-                        burnin_steps,
-                        sampling_steps,
-                        field_name, data_record)
-
-                    batch_res = [(x + " " + y).strip() for (x, y) in zip(batch_res, batch)]
-                sentences += batch_res
-            else:
-                assert 0
+            batch = self._parallel_sequential_generation(
+                clipped_text,
+                clipped_text if "seed" not in data_record else data_record["seed"],
+                batch_size if id != n_batches - 1 else last_batch_size,
+                burnin_steps,
+                sampling_steps,
+                field_name, data_record)
+            sentences += batch
 
         assert len(sentences) == n
 
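
The criterion touched in ppl_criteria_score now requests a ratio explicitly (use_ratio=True) and penalizes only paraphrases whose perplexity exceeds the origin's, via -ppl_weight * max(ppl_ratio - 1, 0)^2. A small worked sketch follows; the ratio values and weight are made-up inputs for illustration, not values from the repository.

# Worked sketch of the perplexity criterion: -ppl_weight * max(ppl_ratio - 1, 0)^2.
# The ratios and weight below are illustrative inputs only.
import numpy as np

def ppl_criteria_score_sketch(ppl_ratios, ppl_weight=5.0):
    if ppl_weight == 0:
        return np.zeros(len(ppl_ratios), dtype="float32")
    return -ppl_weight * (np.maximum(np.asarray(ppl_ratios) - 1, 0) ** 2)

# ratio <= 1 (paraphrase no less fluent than the origin) costs nothing;
# ratio = 2 costs -5.0 * (2 - 1)^2 = -5.0.
print(ppl_criteria_score_sketch([0.9, 1.0, 2.0]))  # [-0. -0. -5.]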

fibber/paraphrase_strategies/asrs_utils_text_parser.py

Lines changed: 0 additions & 116 deletions
This file was deleted.

tests/paraphrase_strategies/test_asrs_utils_text_parser.py

Lines changed: 0 additions & 21 deletions
This file was deleted.
