@@ -1,8 +1,6 @@
 # Implementation of edge probing module.
+from typing import Dict
 
-from typing import Dict, Iterable
-
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -80,19 +78,19 @@ def __init__(self, task, d_inp: int, task_params): |
         if self.is_symmetric or self.single_sided:
             # Use None as dummy padding for readability,
             # so that we can index projs[1] and projs[2]
-            self.projs = [None, self.proj1, self.proj1]
+            self.projs = nn.ModuleList([None, self.proj1, self.proj1])
         else:
             # Separate params for span2
             self.proj2 = self._make_cnn_layer(d_inp)
-            self.projs = [None, self.proj1, self.proj2]
+            self.projs = nn.ModuleList([None, self.proj1, self.proj2])
 
         # Span extractor, shared for both span1 and span2.
         self.span_extractor1 = self._make_span_extractor()
         if self.is_symmetric or self.single_sided:
-            self.span_extractors = [None, self.span_extractor1, self.span_extractor1]
+            self.span_extractors = nn.ModuleList([None, self.span_extractor1, self.span_extractor1])
         else:
             self.span_extractor2 = self._make_span_extractor()
-            self.span_extractors = [None, self.span_extractor1, self.span_extractor2]
+            self.span_extractors = nn.ModuleList([None, self.span_extractor1, self.span_extractor2])
 
         # Classifier gets concatenated projections of span1, span2
         clf_input_dim = self.span_extractors[1].get_output_dim()
@@ -131,11 +129,9 @@ def forward( |
         """
         out = {}
 
-        batch_size = word_embs_in_context.shape[0]
-        out["n_inputs"] = batch_size
-
         # Apply projection CNN layer for each span.
         word_embs_in_context_t = word_embs_in_context.transpose(1, 2)  # needed for CNN layer
+
         se_proj1 = self.projs[1](word_embs_in_context_t).transpose(2, 1).contiguous()
         if not self.single_sided:
             se_proj2 = self.projs[2](word_embs_in_context_t).transpose(2, 1).contiguous()
@@ -169,28 +165,10 @@ def forward( |
             out["loss"] = self.compute_loss(logits[span_mask], batch["labels"][span_mask], task)
 
         if predict:
-            # Return preds as a list.
-            preds = self.get_predictions(logits)
-            out["preds"] = list(self.unbind_predictions(preds, span_mask))
+            out["preds"] = self.get_predictions(logits)
 
         return out
 
-    def unbind_predictions(self, preds: torch.Tensor, masks: torch.Tensor) -> Iterable[np.ndarray]:
-        """ Unpack preds to varying-length numpy arrays.
-
-        Args:
-            preds: [batch_size, num_targets, ...]
-            masks: [batch_size, num_targets] boolean mask
-
-        Yields:
-            np.ndarray for each row of preds, selected by the corresponding row
-            of span_mask.
-        """
-        preds = preds.detach().cpu()
-        masks = masks.detach().cpu()
-        for pred, mask in zip(torch.unbind(preds, dim=0), torch.unbind(masks, dim=0)):
-            yield pred[mask].numpy()  # only non-masked predictions
-
     def get_predictions(self, logits: torch.Tensor):
         """Return class probabilities, same shape as logits.
 
@@ -218,16 +196,6 @@ def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor, task: EdgePro |
         Returns:
             loss: scalar Tensor
         """
-        binary_preds = logits.ge(0).long()  # {0,1}
-
-        # Matthews coefficient and accuracy computed on {0,1} labels.
-        task.mcc_scorer(binary_preds, labels.long())
-        task.acc_scorer(binary_preds, labels.long())
-
-        # F1Measure() expects [total_num_targets, n_classes, 2]
-        # to compute binarized F1.
-        binary_scores = torch.stack([-1 * logits, logits], dim=2)
-        task.f1_scorer(binary_scores, labels)
 
         if self.loss_type == "sigmoid":
             return F.binary_cross_entropy(torch.sigmoid(logits), labels.float())
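
Aside, not part of the diff: a minimal sketch of why the plain Python lists above are wrapped in nn.ModuleList. Submodules held in an ordinary list are not registered with the parent Module, so their parameters are invisible to .parameters(), .cuda(), and state_dict(); nn.ModuleList registers them, and None entries are accepted as dummy padding and skipped during parameter collection. The class names below are made up for illustration.

import torch.nn as nn

class PlainListModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: the Linear layer is NOT registered as a submodule.
        self.projs = [None, nn.Linear(4, 4)]

class ModuleListModule(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers the Linear layer; None serves as dummy padding.
        self.projs = nn.ModuleList([None, nn.Linear(4, 4)])

print(sum(p.numel() for p in PlainListModule().parameters()))   # 0 -- optimizer would miss these weights
print(sum(p.numel() for p in ModuleListModule().parameters()))  # 20 (4x4 weight + 4 bias)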