@@ -53,7 +53,6 @@ def __init__(
         vocab: Vocabulary,
         bert_model: Union[str, AutoModel],
         embedding_dropout: float = 0.0,
-        num_lstms: int = 2,
         initializer: InitializerApplicator = InitializerApplicator(),
         label_smoothing: float = None,
         ignore_span_metric: bool = False,
@@ -68,8 +67,7 @@ def __init__(
         self.frame_role_dict = load_role_frame(FRAME_ROLE_PATH)
         self.restrict_frames = restrict_frames
         self.restrict_roles = restrict_roles
-        config = AutoConfig.from_pretrained(bert_model, output_hidden_states=True)
-        self.transformer = AutoModel.from_pretrained(bert_model, config=config)
+        self.transformer = AutoModel.from_pretrained(bert_model)
        self.frame_criterion = nn.CrossEntropyLoss()
         # add missing labels
         frame_list = load_label_list(FRAME_LIST_PATH)
@@ -83,20 +81,10 @@ def __init__(
         else:
             self.span_metric = None
         self.f1_frame_metric = FBetaMeasure(average="micro")
-        self.predicate_embedding = nn.Embedding(num_embeddings=2, embedding_dim=10)
-        self.lstms = nn.LSTM(
-            config.hidden_size + 10,
-            config.hidden_size,
-            num_layers=num_lstms,
-            dropout=0.2 if num_lstms > 1 else 0,
-            bidirectional=True,
-        )
-        # self.dropout = nn.Dropout(0.4)
-        # self.tag_projection_layer = nn.Linear(config.hidden_size, self.num_classes)
-        self.tag_projection_layer = torch.nn.Sequential(
-            nn.Linear(config.hidden_size * 2, 300), nn.ReLU(), nn.Linear(300, self.num_classes),
+        self.tag_projection_layer = nn.Linear(self.transformer.config.hidden_size, self.num_classes)
+        self.frame_projection_layer = nn.Linear(
+            self.transformer.config.hidden_size, self.frame_num_classes
         )
-        self.frame_projection_layer = nn.Linear(config.hidden_size * 2, self.frame_num_classes)
         self.embedding_dropout = nn.Dropout(p=embedding_dropout)
         self._label_smoothing = label_smoothing
         self.ignore_span_metric = ignore_span_metric
@@ -106,7 +94,6 @@ def forward(  # type: ignore
         self,
         tokens: TextFieldTensors,
         verb_indicator: torch.Tensor,
-        sentence_end: torch.LongTensor,
         frame_indicator: torch.Tensor,
         metadata: List[Any],
         tags: torch.LongTensor = None,
@@ -153,36 +140,18 @@ def forward(  # type: ignore
         """
         mask = get_text_field_mask(tokens)
         input_ids = util.get_token_ids_from_text_field_tensors(tokens)
-        embeddings = self.transformer(input_ids=input_ids, attention_mask=mask)
-        embeddings = embeddings[2][-4:]
-        embeddings = torch.stack(embeddings, dim=0).sum(dim=0)
-        # get sizes
-        batch_size, _, _ = embeddings.size()
+        bert_embeddings, _ = self.transformer(
+            input_ids=input_ids, token_type_ids=verb_indicator, attention_mask=mask,
+        )
         # extract embeddings
-        embedded_text_input = self.embedding_dropout(embeddings)
-        # sentence_mask = (
-        #     torch.arange(mask.shape[1]).unsqueeze(0).repeat(batch_size, 1).to(mask.device)
-        #     < sentence_end.unsqueeze(1).repeat(1, mask.shape[1])
-        # ).long()
-        # cutoff = sentence_end.max().item()
-
-        # encoded_text = embedded_text_input
-        # mask = sentence_mask[:, :cutoff].contiguous()
-        # encoded_text = encoded_text[:, :cutoff, :]
-        # tags = tags[:, :cutoff].contiguous()
-        # frame_tags = frame_tags[:, :cutoff].contiguous()
-        # frame_indicator = frame_indicator[:, :cutoff].contiguous()
-
-        predicate_embeddings = self.predicate_embedding(verb_indicator)
-        # encoded_text = torch.stack((embedded_text_input, predicate_embeddings), dim=0).sum(dim=0)
-        embedded_text_input = torch.cat((embedded_text_input, predicate_embeddings), dim=-1)
-        encoded_text, _ = self.lstms(embedded_text_input)
-        frame_embeddings = encoded_text[frame_indicator == 1]
+        embedded_text_input = self.embedding_dropout(bert_embeddings)
+        frame_embeddings = embedded_text_input[frame_indicator == 1]
+        # get sizes
+        batch_size, sequence_length, _ = embedded_text_input.size()
         # outputs
-        logits = self.tag_projection_layer(encoded_text)
+        logits = self.tag_projection_layer(embedded_text_input)
         frame_logits = self.frame_projection_layer(frame_embeddings)

-        sequence_length = encoded_text.shape[1]
         reshaped_log_probs = logits.view(-1, self.num_classes)
         class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
             [batch_size, sequence_length, self.num_classes]