Commit 7450191

2 parents: 635450e + 08150a7

6 files changed: 1014 additions, 0 deletions
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-

# MIT License
#
# Copyright 2018-2022 New York University Abu Dhabi
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from camel_tools.disambig.bert.unfactored import BERTUnfactoredDisambiguator


__all__ = [
    'BERTUnfactoredDisambiguator',
]
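
Since this file re-exports BERTUnfactoredDisambiguator through __all__, the class can presumably be imported directly from the subpackage rather than from its unfactored submodule. A minimal sketch, assuming this file is the __init__.py of camel_tools.disambig.bert (the file path is not shown in this diff):

# Equivalent to importing from camel_tools.disambig.bert.unfactored
from camel_tools.disambig.bert import BERTUnfactoredDisambiguator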
Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
# -*- coding: utf-8 -*-

# MIT License
#
# Copyright 2018-2022 New York University Abu Dhabi
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
import torch.nn as nn
from torch.utils.data import Dataset


def _prepare_sentences(sentences, placeholder=''):
    """Encapsulates the input sentences into PrepSentence objects.

    Args:
        sentences (:obj:`list` of :obj:`list` of :obj:`str`): The input
            sentences.
        placeholder (:obj:`str`): The placeholder label assigned to every
            word. Defaults to ''.

    Returns:
        :obj:`list` of :obj:`PrepSentence`: The list of PrepSentence objects.
    """

    guid_index = 1
    prepared_sentences = []

    for words in sentences:
        labels = [placeholder] * len(words)
        prepared_sentences.append(_PrepSentence(guid=f"{guid_index}",
                                                words=words,
                                                labels=labels))
        guid_index += 1

    return prepared_sentences


class _PrepSentence:
    """A single input sentence for token classification.

    Args:
        guid (:obj:`str`): Unique id for the sentence.
        words (:obj:`list` of :obj:`str`): List of words of the sentence.
        labels (:obj:`list` of :obj:`str`): The labels for each word
            of the sentence.
    """

    def __init__(self, guid, words, labels):
        self.guid = guid
        self.words = words
        self.labels = labels


class MorphDataset(Dataset):
    """Morph PyTorch Dataset

    Args:
        sentences (:obj:`list` of :obj:`list` of :obj:`str`): The input
            sentences.
        tokenizer (:obj:`PreTrainedTokenizer`): BERT's pretrained tokenizer.
        labels (:obj:`list` of :obj:`str`): The labels which the model was
            trained to classify.
        max_seq_length (:obj:`int`): Maximum sentence length.
    """

    def __init__(self, sentences, tokenizer, labels, max_seq_length):
        prepared_sentences = _prepare_sentences(sentences,
                                                placeholder=labels[0])
        # Use cross entropy ignore_index as padding label id so that only
        # real label ids contribute to the loss later.
        self.pad_token_label_id = nn.CrossEntropyLoss().ignore_index
        self.features = self._featurize_input(
            prepared_sentences,
            labels,
            max_seq_length,
            tokenizer,
            cls_token=tokenizer.cls_token,
            sep_token=tokenizer.sep_token,
            pad_token=tokenizer.pad_token_id,
            pad_token_segment_id=tokenizer.pad_token_type_id,
            pad_token_label_id=self.pad_token_label_id,
        )

    def _featurize_input(self, prepared_sentences, label_list, max_seq_length,
                         tokenizer, cls_token="[CLS]", cls_token_segment_id=0,
                         sep_token="[SEP]", pad_token=0, pad_token_segment_id=0,
                         pad_token_label_id=-100, sequence_a_segment_id=0,
                         mask_padding_with_zero=True):
        """Featurizes the input which will be fed to the fine-tuned BERT
        model.

        Args:
            prepared_sentences (:obj:`list` of :obj:`PrepSentence`): List of
                PrepSentence objects.
            label_list (:obj:`list` of :obj:`str`): The labels which the
                model was trained to classify.
            max_seq_length (:obj:`int`): Maximum sequence length.
            tokenizer (:obj:`PreTrainedTokenizer`): BERT's pretrained
                tokenizer.
            cls_token (:obj:`str`): BERT's CLS token. Defaults to [CLS].
            cls_token_segment_id (:obj:`int`): BERT's CLS token segment id.
                Defaults to 0.
            sep_token (:obj:`str`): BERT's SEP token. Defaults to [SEP].
            pad_token (:obj:`int`): BERT's padding token id. Defaults to 0.
            pad_token_segment_id (:obj:`int`): BERT's padding token segment
                id. Defaults to 0.
            pad_token_label_id (:obj:`int`): BERT's padding token label id.
                Defaults to -100.
            sequence_a_segment_id (:obj:`int`): BERT's segment id.
                Defaults to 0.
            mask_padding_with_zero (:obj:`bool`): Whether to mask the padding
                tokens with zero or not. Defaults to True.

        Returns:
            :obj:`list` of :obj:`dict`: The list of dicts of the needed
            features.
        """

        label_map = {label: i for i, label in enumerate(label_list)}
        features = []

        for sent_id, sentence in enumerate(prepared_sentences):
            tokens = []
            label_ids = []

            for word, label in zip(sentence.words, sentence.labels):
                word_tokens = tokenizer.tokenize(word)
                # bert-base-multilingual-cased sometimes outputs nothing ([])
                # when calling tokenize with just a space.
                if len(word_tokens) > 0:
                    tokens.append(word_tokens)
                    # Use the real label id for the first token of the word,
                    # and padding ids for the remaining tokens
                    label_ids.append([label_map[label]] +
                                     [pad_token_label_id] *
                                     (len(word_tokens) - 1))

            token_segments = []
            token_segment = []
            label_ids_segments = []
            label_ids_segment = []
            num_word_pieces = 0
            seg_seq_length = max_seq_length - 2

            # Dealing with empty sentences
            if len(tokens) == 0:
                data = self._add_special_tokens(token_segment,
                                                label_ids_segment,
                                                tokenizer,
                                                max_seq_length,
                                                cls_token,
                                                sep_token, pad_token,
                                                cls_token_segment_id,
                                                pad_token_segment_id,
                                                pad_token_label_id,
                                                sequence_a_segment_id,
                                                mask_padding_with_zero)
                # Adding sentence id
                data['sent_id'] = sent_id
                features.append(data)
            else:
                # Chunking the tokenized sentence into multiple segments
                # if it's longer than max_seq_length - 2
                for idx, word_pieces in enumerate(tokens):
                    if num_word_pieces + len(word_pieces) > seg_seq_length:
                        data = self._add_special_tokens(token_segment,
                                                        label_ids_segment,
                                                        tokenizer,
                                                        max_seq_length,
                                                        cls_token,
                                                        sep_token, pad_token,
                                                        cls_token_segment_id,
                                                        pad_token_segment_id,
                                                        pad_token_label_id,
                                                        sequence_a_segment_id,
                                                        mask_padding_with_zero)
                        # Adding sentence id
                        data['sent_id'] = sent_id
                        features.append(data)

                        token_segments.append(token_segment)
                        label_ids_segments.append(label_ids_segment)
                        token_segment = list(word_pieces)
                        label_ids_segment = list(label_ids[idx])
                        num_word_pieces = len(word_pieces)
                    else:
                        token_segment.extend(word_pieces)
                        label_ids_segment.extend(label_ids[idx])
                        num_word_pieces += len(word_pieces)

                # Adding the last segment
                if len(token_segment) > 0:
                    data = self._add_special_tokens(token_segment,
                                                    label_ids_segment,
                                                    tokenizer,
                                                    max_seq_length,
                                                    cls_token,
                                                    sep_token, pad_token,
                                                    cls_token_segment_id,
                                                    pad_token_segment_id,
                                                    pad_token_label_id,
                                                    sequence_a_segment_id,
                                                    mask_padding_with_zero)
                    # Adding sentence id
                    data['sent_id'] = sent_id
                    features.append(data)

                    token_segments.append(token_segment)
                    label_ids_segments.append(label_ids_segment)

            # DEBUG: Making sure we got all segments correctly
            # assert sum([len(_) for _ in label_ids_segments]) == \
            #        sum([len(_) for _ in label_ids])

            # assert sum([len(_) for _ in token_segments]) == \
            #        sum([len(_) for _ in tokens])

        return features

    def _add_special_tokens(self, tokens, label_ids, tokenizer, max_seq_length,
                            cls_token, sep_token, pad_token,
                            cls_token_segment_id, pad_token_segment_id,
                            pad_token_label_id, sequence_a_segment_id,
                            mask_padding_with_zero):

        _tokens = list(tokens)
        _label_ids = list(label_ids)

        _tokens += [sep_token]
        _label_ids += [pad_token_label_id]
        segment_ids = [sequence_a_segment_id] * len(_tokens)

        _tokens = [cls_token] + _tokens
        _label_ids = [pad_token_label_id] + _label_ids
        segment_ids = [cls_token_segment_id] + segment_ids

        input_ids = tokenizer.convert_tokens_to_ids(_tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only
        # real tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_seq_length - len(input_ids)
        input_ids += [pad_token] * padding_length
        input_mask += [0 if mask_padding_with_zero else 1] * padding_length
        segment_ids += [pad_token_segment_id] * padding_length
        _label_ids += [pad_token_label_id] * padding_length

        return {'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(input_mask),
                'token_type_ids': torch.tensor(segment_ids),
                'label_ids': torch.tensor(_label_ids)}

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i):
        return self.features[i]
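
For context, a minimal usage sketch of MorphDataset, assuming a Hugging Face PreTrainedTokenizer is available; the checkpoint name, label list, and sentences below are illustrative placeholders, not part of this commit:

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Hypothetical inputs for illustration only: any BERT tokenizer and any
# label inventory the model was fine-tuned on would work here.
tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
labels = ['label_a', 'label_b']
sentences = [['Hello', 'world'], ['A', 'second', 'sentence']]

dataset = MorphDataset(sentences, tokenizer, labels, max_seq_length=128)
loader = DataLoader(dataset, batch_size=2)

for batch in loader:
    # The default collate function stacks each feature dict into batched
    # tensors keyed by 'input_ids', 'attention_mask', 'token_type_ids',
    # 'label_ids', and 'sent_id'.
    print(batch['input_ids'].shape)  # (batch_size, max_seq_length)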
