diff --git a/JaQuAD.ipynb b/JaQuAD.ipynb index 10cfbc9..fde85ce 100644 --- a/JaQuAD.ipynb +++ b/JaQuAD.ipynb @@ -125,7 +125,7 @@ " 'batch_size': 32, # <=32 for TPUv2-8\n", " 'lr': 2e-5, # Learning Rate\n", " 'max_length': 384, # Max Length input size\n", - " 'doc_stride': 128, # The interval of the context when splitting is needed\n", + " 'doc_stride': 128, # The overlap of the context when splitting is needed\n", " 'epochs': 4, # Max Epochs\n", " 'dataset': 'SkelterLabsInc/JaQuAD',\n", " 'huggingface_auth_token': None,\n", @@ -195,7 +195,8 @@ " val += [padding] * pad_len\n", " return val\n", "\n", - " for i in range(0, input_len - max_seq_len + stride, stride):\n", + " step = max_seq_len - question_len - stride\n", + " for i in range(0, max(context_len - stride, step), step):\n", " span = {key: make_value(val, i) for key, val in inputs.items()}\n", " answer_start = answer_start_position - i\n", " answer_end = answer_end_position - i\n", @@ -482,11 +483,12 @@ "\n", " ctx_start = tokens.index(self.tokenizer.sep_token_id) + 1\n", " answer_start_index = ctx_start\n", - " answer_end_index = len(offsets) - 1\n", - " while offsets[answer_start_index][0] < start_char:\n", + " while offsets[answer_start_index][1] < start_char:\n", " answer_start_index += 1\n", - " while offsets[answer_end_index][1] > start_char + len(answer):\n", - " answer_end_index -= 1\n", + " answer_end_index = answer_start_index\n", + " while answer_end_index < len(offsets) \\\n", + " and offsets[answer_end_index][0] < start_char + len(answer):\n", + " answer_end_index += 1\n", "\n", " span_inputs = {\n", " 'input_ids': tokens,\n", @@ -660,7 +662,7 @@ }, "outputs": [], "source": [ - "def get_answers(model: AutoModelForQuestionAnswering,\n", + "def get_answers(model: QAModel,\n", " context: str,\n", " question: str,\n", " n_best_size: int = 5,\n", @@ -686,7 +688,7 @@ " 1:-1].tolist()\n", " end_indexes = np.argsort(end_logits)[-1:-n_best_size - 1:-1].tolist()\n", " cur_offsets = offsets[i:]\n", - " i += doc_stride\n", + " i += max_seq_len - question_len - doc_stride\n", " for start_index in start_indexes:\n", " for end_index in end_indexes:\n", " if 0 < start_index <= end_index < len(cur_offsets):\n",