Commit fb71b24

release 2.4

1 parent 00d407d commit fb71b24
File tree

setup.py
transformer_srl/dataset_readers.py
transformer_srl/models.py

3 files changed: +13 −46 lines

setup.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@

 setuptools.setup(
     name="transformer_srl", # Replace with your own username
-    version="3.0rc3",
+    version="2.4",
     author="Riccardo Orlando",
     author_email="orlandoricc@gmail.com",
     description="SRL Transformer model",
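The only change here rolls the version string back from the 3.0 release candidate to a stable 2.4. A minimal sanity check, assuming the package is installed in the current environment:

# Confirm the installed distribution reports the version set in setup.py
# ("2.4" after this commit). Requires Python 3.8+ for importlib.metadata.
import importlib.metadata

print(importlib.metadata.version("transformer_srl"))  # expected: 2.4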

transformer_srl/dataset_readers.py

Lines changed: 0 additions & 2 deletions

@@ -258,15 +258,13 @@ def text_to_instance(  # type: ignore
         verb_indicator = SequenceLabelField(new_verbs, text_field)
         frame_indicator = SequenceLabelField(frame_indicator, text_field)

-        sep_index = wordpieces.index(self.tokenizer.sep_token)

         metadata_dict["offsets"] = start_offsets

         fields: Dict[str, Field] = {
             "tokens": text_field,
             "verb_indicator": verb_indicator,
             "frame_indicator": frame_indicator,
-            "sentence_end": ArrayField(np.array(sep_index + 1, dtype=np.int64), dtype=np.int64),
         }

         if all(x == 0 for x in verb_label):
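With sentence_end removed from the model's forward signature (see models.py below), the reader no longer needs to locate the [SEP] wordpiece. For reference, a standalone sketch of what the two deleted lines computed, over a hypothetical wordpiece list:

# Sketch of the removed computation: the index one past the first [SEP]
# wordpiece, previously shipped to the model as the "sentence_end" field.
# The token list below is illustrative, not taken from the reader.
wordpieces = ["[CLS]", "The", "cat", "sat", "[SEP]", "sit.01", "[SEP]"]
sep_index = wordpieces.index("[SEP]")  # 4: .index() returns the first match
sentence_end = sep_index + 1           # 5: no longer needed downstream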

transformer_srl/models.py

Lines changed: 12 additions & 43 deletions

@@ -53,7 +53,6 @@ def __init__(
         vocab: Vocabulary,
         bert_model: Union[str, AutoModel],
         embedding_dropout: float = 0.0,
-        num_lstms: int = 2,
         initializer: InitializerApplicator = InitializerApplicator(),
         label_smoothing: float = None,
         ignore_span_metric: bool = False,
@@ -68,8 +67,7 @@ def __init__(
         self.frame_role_dict = load_role_frame(FRAME_ROLE_PATH)
         self.restrict_frames = restrict_frames
         self.restrict_roles = restrict_roles
-        config = AutoConfig.from_pretrained(bert_model, output_hidden_states=True)
-        self.transformer = AutoModel.from_pretrained(bert_model, config=config)
+        self.transformer = AutoModel.from_pretrained(bert_model)
         self.frame_criterion = nn.CrossEntropyLoss()
         # add missing labels
         frame_list = load_label_list(FRAME_LIST_PATH)
@@ -83,20 +81,10 @@
         else:
             self.span_metric = None
         self.f1_frame_metric = FBetaMeasure(average="micro")
-        self.predicate_embedding = nn.Embedding(num_embeddings=2, embedding_dim=10)
-        self.lstms = nn.LSTM(
-            config.hidden_size + 10,
-            config.hidden_size,
-            num_layers=num_lstms,
-            dropout=0.2 if num_lstms > 1 else 0,
-            bidirectional=True,
-        )
-        # self.dropout = nn.Dropout(0.4)
-        # self.tag_projection_layer = nn.Linear(config.hidden_size, self.num_classes)
-        self.tag_projection_layer = torch.nn.Sequential(
-            nn.Linear(config.hidden_size * 2, 300), nn.ReLU(), nn.Linear(300, self.num_classes),
+        self.tag_projection_layer = nn.Linear(self.transformer.config.hidden_size, self.num_classes)
+        self.frame_projection_layer = nn.Linear(
+            self.transformer.config.hidden_size, self.frame_num_classes
         )
-        self.frame_projection_layer = nn.Linear(config.hidden_size * 2, self.frame_num_classes)
         self.embedding_dropout = nn.Dropout(p=embedding_dropout)
         self._label_smoothing = label_smoothing
         self.ignore_span_metric = ignore_span_metric
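Taken together, these __init__ hunks drop the predicate-embedding + stacked-BiLSTM encoder and the two-layer MLP tag head, leaving plain linear projections over the transformer's hidden size. A standalone sketch of the resulting head layout; the checkpoint name and label-space sizes are placeholders, not the project's actual configuration (the real sizes come from the vocabulary):

# Simplified architecture after this commit: linear heads directly on the
# encoder's hidden size, no LSTM and no separate predicate embedding.
import torch.nn as nn
from transformers import AutoModel

transformer = AutoModel.from_pretrained("bert-base-cased")  # placeholder checkpoint
hidden_size = transformer.config.hidden_size                # 768 for bert-base
num_classes, frame_num_classes = 100, 1000                  # placeholder sizes
tag_projection_layer = nn.Linear(hidden_size, num_classes)
frame_projection_layer = nn.Linear(hidden_size, frame_num_classes)
embedding_dropout = nn.Dropout(p=0.0)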
@@ -106,7 +94,6 @@ def forward(  # type: ignore
         self,
         tokens: TextFieldTensors,
         verb_indicator: torch.Tensor,
-        sentence_end: torch.LongTensor,
         frame_indicator: torch.Tensor,
         metadata: List[Any],
         tags: torch.LongTensor = None,
@@ -153,36 +140,18 @@
         """
         mask = get_text_field_mask(tokens)
         input_ids = util.get_token_ids_from_text_field_tensors(tokens)
-        embeddings = self.transformer(input_ids=input_ids, attention_mask=mask)
-        embeddings = embeddings[2][-4:]
-        embeddings = torch.stack(embeddings, dim=0).sum(dim=0)
-        # get sizes
-        batch_size, _, _ = embeddings.size()
+        bert_embeddings, _ = self.transformer(
+            input_ids=input_ids, token_type_ids=verb_indicator, attention_mask=mask,
+        )
         # extract embeddings
-        embedded_text_input = self.embedding_dropout(embeddings)
-        # sentence_mask = (
-        #     torch.arange(mask.shape[1]).unsqueeze(0).repeat(batch_size, 1).to(mask.device)
-        #     < sentence_end.unsqueeze(1).repeat(1, mask.shape[1])
-        # ).long()
-        # cutoff = sentence_end.max().item()
-
-        # encoded_text = embedded_text_input
-        # mask = sentence_mask[:, :cutoff].contiguous()
-        # encoded_text = encoded_text[:, :cutoff, :]
-        # tags = tags[:, :cutoff].contiguous()
-        # frame_tags = frame_tags[:, :cutoff].contiguous()
-        # frame_indicator = frame_indicator[:, :cutoff].contiguous()
-
-        predicate_embeddings = self.predicate_embedding(verb_indicator)
-        # encoded_text = torch.stack((embedded_text_input, predicate_embeddings), dim=0).sum(dim=0)
-        embedded_text_input = torch.cat((embedded_text_input, predicate_embeddings), dim=-1)
-        encoded_text, _ = self.lstms(embedded_text_input)
-        frame_embeddings = encoded_text[frame_indicator == 1]
+        embedded_text_input = self.embedding_dropout(bert_embeddings)
+        frame_embeddings = embedded_text_input[frame_indicator == 1]
+        # get sizes
+        batch_size, sequence_length, _ = embedded_text_input.size()
         # outputs
-        logits = self.tag_projection_layer(encoded_text)
+        logits = self.tag_projection_layer(embedded_text_input)
         frame_logits = self.frame_projection_layer(frame_embeddings)

-        sequence_length = encoded_text.shape[1]
         reshaped_log_probs = logits.view(-1, self.num_classes)
         class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
             [batch_size, sequence_length, self.num_classes]
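The forward pass changes in two ways: the encoder's last hidden state is used directly instead of summing the last four hidden layers (which is why output_hidden_states=True disappears from __init__), and the 0/1 verb indicator is passed as token_type_ids, so the predicate is marked through BERT's segment embeddings rather than a learned 10-dim predicate embedding concatenated before an LSTM. A standalone sketch of the new encoding step with a hypothetical sentence and predicate position; note the tuple unpacking in the diff assumes an older transformers release, while recent versions return a ModelOutput, so indexing is used here instead:

# Mark the predicate via token_type_ids and take the last hidden state as
# the token embeddings, mirroring the new forward() path.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModel.from_pretrained("bert-base-cased")

enc = tokenizer("The cat sat on the mat", return_tensors="pt")
verb_indicator = torch.zeros_like(enc["input_ids"])
verb_indicator[0, 3] = 1  # "sat" is the predicate (illustrative position)

outputs = model(
    input_ids=enc["input_ids"],
    token_type_ids=verb_indicator,
    attention_mask=enc["attention_mask"],
)
bert_embeddings = outputs[0]  # last hidden state: [1, seq_len, hidden_size]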
