88
99from fibber import log
1010from fibber .metrics .bert_lm_utils import get_lm
11- from fibber .paraphrase_strategies .asrs_utils_text_parser import TextParser
1211from fibber .paraphrase_strategies .asrs_utils_wpe import get_wordpiece_emb
1312from fibber .paraphrase_strategies .strategy_base import StrategyBase
1413
@@ -132,7 +131,7 @@ def ppl_criteria_score(origin, paraphrases, ppl_metric, ppl_weight):
132131 """
133132 if ppl_weight == 0 :
134133 return np .zeros (len (paraphrases ), dtype = "float32" )
135- ppl_ratio = ppl_metric .measure_batch (origin , paraphrases )
134+ ppl_ratio = ppl_metric .measure_batch (origin , paraphrases , use_ratio = True )
136135 return - ppl_weight * (np .maximum (np .asarray (ppl_ratio ) - 1 , 0 ) ** 2 )
137136
138137
@@ -272,11 +271,9 @@ class ASRSStrategy(StrategyBase):
272271 ("burnin_criteria_schedule" , str , "1" , ("the schedule decides how strict the criteria is "
273272 "used. options are [linear, 0, 1]." )),
274273 ("seed_option" , str , "origin" , ("the option for seed sentences in generation. "
275- "choose from [origin, auto, dynamic_len]." )),
274+ "choose from [origin, dynamic_len]." )),
276275 ("dynamic_len_min" , int , - 3 , "change length min." ),
277276 ("dynamic_len_max" , int , 3 , "change length max." ),
278- ("split_sentence" , str , "auto" , "split paragraph to sentence. options are [0, 1, auto]." ),
279- ("stanza_port" , int , 9000 , "stanza port" ),
280277 ("lm_option" , str , "finetune" , "choose from [pretrain, finetune, adv]." ),
281278 ("lm_steps" , int , 5000 , "lm training steps." ),
282279 ("clf_weight" , float , 3 , "weight for the clf score in the criteria." ),
@@ -304,7 +301,7 @@ def fit(self, trainset):
304301 self ._sim_metric = self ._metric_bundle .get_metric (
305302 self ._strategy_config ["sim_metric" ])
306303 self ._clf_metric = self ._metric_bundle .get_target_classifier ()
307- self ._ppl_metric = self ._metric_bundle .get_metric ("GPT2PerplexityMetric " )
304+ self ._ppl_metric = self ._metric_bundle .get_metric ("BertPerplexityMetric " )
308305
309306 # load word piece embeddings.
310307 wpe = get_wordpiece_emb (self ._output_dir , self ._dataset_name , trainset , self ._device )
@@ -332,13 +329,6 @@ def fit(self, trainset):
332329 else :
333330 assert 0
334331
335- # load text parser
336- if (self ._strategy_config ["seed_option" ] != "origin"
337- or self ._strategy_config ["split_sentence" ] != "0" ):
338- self ._text_parser = TextParser (self ._strategy_config ["stanza_port" ])
339- else :
340- self ._text_parser = None
341-
342332 self ._stats = {
343333 "all" : 0 ,
344334 "accept" : 0
@@ -351,14 +341,6 @@ def _parallel_sequential_generation(self, original_text, seed, batch_size, burni
351341 batch_tensor = torch .tensor (
352342 [self ._tokenizer .convert_tokens_to_ids (seq )] * batch_size ).to (self ._device )
353343 seq_len = [len (seq )] * batch_size
354- elif self ._strategy_config ["seed_option" ] == "auto" :
355- seeds = self ._text_parser .phrase_level_shuffle (seed , batch_size )
356- seq = [self ._tokenizer .tokenize (x ) for x in seeds ]
357- seq_len = ([len (x ) + 2 for x in seq ])
358- max_len = max (seq_len )
359- seq = [["[CLS]" ] + x + ["[SEP]" ] + ["[PAD]" ] * (max_len - len (x ) - 2 ) for x in seq ]
360- batch_tensor = torch .tensor (
361- [self ._tokenizer .convert_tokens_to_ids (x ) for x in seq ]).to (self ._device )
362344 elif self ._strategy_config ["seed_option" ] == "dynamic_len" :
363345 seq_raw = self ._tokenizer .tokenize (seed )
364346 seq = []
@@ -536,7 +518,7 @@ def paraphrase_example(self, data_record, field_name, n):
536518 self ._bert_lm = self ._bert_lms [data_record ["label" ]]
537519 self ._bert_lm .to (self ._device )
538520
539- clipped_text = " " . join ( data_record [field_name ]. split ()[: 200 ])
521+ clipped_text = data_record [field_name ]
540522 clipped_text = process_text (clipped_text , PRE_PROCESSING_PATTERN )
541523 batch_size = self ._strategy_config ["batch_size" ]
542524
@@ -550,53 +532,14 @@ def paraphrase_example(self, data_record, field_name, n):
550532 burnin_steps = self ._strategy_config ["burnin_steps" ]
551533 sampling_steps = self ._strategy_config ["sampling_steps" ]
552534
553- if self ._strategy_config ["split_sentence" ] == "0" :
554- batch = self ._parallel_sequential_generation (
555- clipped_text ,
556- clipped_text if "seed" not in data_record else data_record ["seed" ],
557- batch_size if id != n_batches - 1 else last_batch_size ,
558- burnin_steps ,
559- sampling_steps ,
560- field_name , data_record )
561- sentences += batch
562- elif self ._strategy_config ["split_sentence" ] in ["1" , "auto" ]:
563- splitted_text_ori = self ._text_parser .split_paragraph_to_sentences (
564- clipped_text )
565-
566- if self ._strategy_config ["split_sentence" ] == "auto" :
567- splitted_text = []
568- current_text = ""
569- for s in splitted_text_ori :
570- current_text += " " + s
571- if len (current_text .split ()) > AUTO_SENTENCE_LEN_THRESHOLD :
572- splitted_text .append (current_text )
573- current_text = ""
574- if len (current_text .split ()) > 0 :
575- splitted_text .append (current_text )
576-
577- burnin_steps = self ._strategy_config ["burnin_steps" ] // len (splitted_text )
578- sampling_steps = self ._strategy_config ["sampling_steps" ] // len (splitted_text )
579- elif self ._strategy_config ["split_sentence" ] == "1" :
580- splitted_text = splitted_text_ori
581- burnin_steps = self ._strategy_config ["burnin_steps" ] // len (splitted_text )
582- sampling_steps = self ._strategy_config ["sampling_steps" ] // len (splitted_text )
583- else :
584- assert 0
585-
586- batch_res = ["" ] * (batch_size if id != n_batches - 1 else last_batch_size )
587- for text in splitted_text :
588- batch = self ._parallel_sequential_generation (
589- text ,
590- text ,
591- batch_size if id != n_batches - 1 else last_batch_size ,
592- burnin_steps ,
593- sampling_steps ,
594- field_name , data_record )
595-
596- batch_res = [(x + " " + y ).strip () for (x , y ) in zip (batch_res , batch )]
597- sentences += batch_res
598- else :
599- assert 0
535+ batch = self ._parallel_sequential_generation (
536+ clipped_text ,
537+ clipped_text if "seed" not in data_record else data_record ["seed" ],
538+ batch_size if id != n_batches - 1 else last_batch_size ,
539+ burnin_steps ,
540+ sampling_steps ,
541+ field_name , data_record )
542+ sentences += batch
600543
601544 assert len (sentences ) == n
602545
0 commit comments