diff --git a/pipelines.py b/pipelines.py index 2c9c681..7701fd2 100644 --- a/pipelines.py +++ b/pipelines.py @@ -134,12 +134,14 @@ def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers): for i, answer in enumerate(answers): if len(answer) == 0: continue for answer_text in answer: - sent = sents[i] + sent = sents[i].lower() sents_copy = sents[:] - answer_text = answer_text.strip() - - ans_start_idx = sent.index(answer_text) + answer_text = answer_text.strip().lower() + try: + ans_start_idx = sent.index(answer_text) + except (ValueError,AssertionError): + continue sent = f"{sent[:ans_start_idx]} {answer_text} {sent[ans_start_idx + len(answer_text): ]}" sents_copy[i] = sent