
Commit 9b05fcf

batch["coref_mentions"] is now a sparse tensor to optimize computations/storage
1 parent 947c289 commit 9b05fcf
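
For context: a sparse COO tensor stores only the coordinates and values of its non-zero entries, so a coreference label matrix that is almost entirely zeros takes memory proportional to the number of non-zeros rather than to spans_nb². A minimal sketch of the representation (toy sizes and indices, not taken from this commit):

    import torch

    # 4 candidate spans; span 2 corefers with span 0, span 3 with span 1
    spans_nb = 4
    indices = torch.tensor([[2, 3], [0, 1]])  # (ndim, nnz) coordinates
    values = torch.ones(2)

    labels = torch.sparse_coo_tensor(indices, values, (spans_nb, spans_nb))
    print(labels.to_dense())
    # tensor([[0., 0., 0., 0.],
    #         [0., 0., 0., 0.],
    #         [1., 0., 0., 0.],
    #         [0., 1., 0., 0.]])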

File tree

1 file changed: +58 -25 lines

tibert/bertcoref.py

Lines changed: 58 additions & 25 deletions
@@ -76,21 +76,21 @@ class CoreferenceDocument:
     def __len__(self) -> int:
         return len(self.tokens)
 
-    def coref_labels(self, max_span_size: int) -> List[List[int]]:
+    def coref_labels(self, max_span_size: int) -> torch.Tensor:
         """
-        :return: a list of shape ``(spans_nb, spans_nb + 1)``.
-            when ``out[i][j] == 1``, span j is the preceding
-            coreferent mention of span i. when ``j == spans_nb``,
-            i has no preceding coreferent mention.
+        :return: a sparse COO tensor of shape ``(spans_nb, spans_nb +
+            1)``. when ``out[i][j] == 1``, span j is the
+            preceding coreferent mention of span i. when ``j ==
+            spans_nb``, i has no preceding coreferent mention.
         """
         spans_idx = {
             indices: i
             for i, indices in enumerate(spans_indexs(self.tokens, max_span_size))
         }
         spans_nb = len(spans_idx)
 
-        # labels = [[0] * (spans_nb + 1) for _ in range(spans_nb)]
-        labels = np.zeros((spans_nb, spans_nb + 1))
+        label_indices = []
+        label_values = []
 
         # spans in a coref chain : mark all antecedents
         for chain in self.coref_chains:
@@ -110,24 +110,35 @@ def coref_labels(self, max_span_size: int) -> List[List[int]]:
                 other_mention_idx = spans_idx[
                     (other_mention.start_idx, other_mention.end_idx)
                 ]
-                labels[mention_idx][other_mention_idx] = 1
+                label_indices.append([mention_idx, other_mention_idx])
+                label_values.append(1)
+
+        if len(label_indices) == 0:
+            labels_t = torch.sparse_coo_tensor(size=(spans_nb, spans_nb))  # type: ignore
+        else:
+            labels_t = torch.sparse_coo_tensor(
+                torch.tensor(label_indices).t(), label_values, (spans_nb, spans_nb)
+            )
 
         # spans without preceding mentions : mark preceding mention to
         # be the null span
-        for i in range(len(labels)):
-            if labels[i].sum() == 0:
-                labels[i][spans_nb] = 1
+        null_t = torch.zeros(spans_nb, 1)
+        for i in range(spans_nb):
+            if labels_t[i].sum() == 0:
+                null_t[i][0] = 1
+        labels_t = torch.cat([labels_t, null_t.to_sparse_coo()], dim=1)
+        assert labels_t.shape == (spans_nb, spans_nb + 1)
 
-        return labels.tolist()
+        return labels_t
 
-    def mention_labels(self, max_span_size: int) -> List[int]:
+    def mention_labels(self, max_span_size: int) -> torch.Tensor:
         """
         :return: a list of shape ``(spans_nb)``
         """
         spans_idx = spans_indexs(self.tokens, max_span_size)
         spans_nb = len(spans_idx)
 
-        labels = [0 for _ in range(spans_nb)]
+        labels = torch.zeros(spans_nb)
 
         for chain in self.coref_chains:
             for mention in chain:
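
The null-span bookkeeping in the hunk above can be exercised in isolation. A minimal sketch, assuming a toy 3-span document (variable names mirror the diff, but the data is made up, and everything is kept float32 so the dtypes of the two concatenated sparse tensors agree):

    import torch

    spans_nb = 3
    # only span 2 has a preceding coreferent mention (span 0)
    labels_t = torch.sparse_coo_tensor(
        torch.tensor([[2], [0]]), torch.ones(1), (spans_nb, spans_nb)
    )

    # rows summing to 0 get a 1 in an extra "null antecedent" column
    null_t = torch.zeros(spans_nb, 1)
    for i in range(spans_nb):
        if labels_t[i].sum() == 0:
            null_t[i][0] = 1
    labels_t = torch.cat([labels_t, null_t.to_sparse_coo()], dim=1)

    assert labels_t.shape == (spans_nb, spans_nb + 1)
    print(labels_t.to_dense())
    # tensor([[0., 0., 0., 1.],
    #         [0., 0., 0., 1.],
    #         [1., 0., 0., 0.]])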
@@ -142,7 +153,7 @@ def mention_labels(self, max_span_size: int) -> List[int]:
 
         return labels
 
-    def document_labels(self, max_span_size: int) -> Tuple[List[List[int]], List[int]]:
+    def document_labels(self, max_span_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
         return (self.coref_labels(max_span_size), self.mention_labels(max_span_size))
 
     def prepared_document(
@@ -272,14 +283,14 @@ def from_wpieced_to_tokenized(
     @staticmethod
     def from_labels(
         tokens: List[str],
-        coref_labels: List[List[int]],
-        mention_labels: List[int],
+        coref_labels: torch.Tensor,
+        mention_labels: torch.Tensor,
         max_span_size: int,
     ) -> CoreferenceDocument:
         """Construct a CoreferenceDocument using labels
 
         :param tokens:
-        :param coref_labels: ``(spans_nb, spans_nb + 1)``
+        :param coref_labels: sparse tensor of shape ``(spans_nb, spans_nb + 1)``
         :param mention_labels: ``(spans_nb)``
         :param max_span_size:
         """
@@ -403,18 +414,25 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]:
 
         for document, tokens in zip(documents, batch["input_ids"]):  # type: ignore
             document.tokens = tokens
-
         labels = [doc.document_labels(self.max_span_size) for doc in documents]
-        batch["coref_labels"] = [coref_labels for coref_labels, _ in labels]
-        batch["mention_labels"] = [mention_labels for _, mention_labels in labels]
 
-        return BatchEncoding(
+        del batch["coref_labels"]
+        del batch["mention_labels"]
+        batch = BatchEncoding(
             {
                 k: torch.tensor(v, dtype=torch.int64, device=torch.device(self.device))
                 for k, v in batch.items()
             },
             encoding=batch.encodings,
         )
+        batch["coref_labels"] = torch.stack(
+            [coref_labels for coref_labels, _ in labels]
+        )
+        batch["mention_labels"] = torch.stack(
+            [mention_labels for _, mention_labels in labels]
+        )
+
+        return batch
 
 
 class CoreferenceDataset(Dataset):
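
torch.stack accepts sparse COO inputs, which is what lets the collator above batch the per-document label tensors without densifying them; since torch.stack requires equal shapes, this relies on every document in the batch having the same padded number of spans. A small self-contained sketch (toy_labels is a hypothetical helper, not part of tibert):

    import torch

    def toy_labels(i: int) -> torch.Tensor:
        # one non-zero per "document", at a different row each time
        return torch.sparse_coo_tensor(
            torch.tensor([[i], [0]]), torch.ones(1), (4, 5)
        )

    batch_labels = torch.stack([toy_labels(0), toy_labels(1)])
    assert batch_labels.shape == (2, 4, 5)
    assert batch_labels.is_sparse  # the batch itself stays sparse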
@@ -1346,7 +1364,7 @@ def forward(
         :param attention_mask: a tensor of shape ``(b, s)``
         :param token_type_ids: a tensor of shape ``(b, s)``
         :param position_ids: a tensor of shape ``(b, s)``
-        :param coref_labels: a tensor of shape ``(b, p, p)``
+        :param coref_labels: a sparse tensor of shape ``(b, p, p)``
         :param mention_labels: a tensor of shape ``(b, p)``
         :param return_hidden_state: if ``True``, set the hidden_state of
             ``BertCoreferenceResolutionOutput``
@@ -1478,17 +1496,32 @@ def forward(
         # -- loss computation --
         loss = None
         if coref_labels is not None and mention_labels is not None:
+
             # -- coref loss
-            selected_coref_labels = batch_index_select(
-                coref_labels, 1, top_mentions_index
+
+            # NOTE: we have to rely on such a loop, as torch.gather
+            # cannot be used on sparse tensors, which prevents using
+            # batch_index_select
+            selected_coref_labels = torch.stack(
+                [
+                    torch.index_select(coref_labels[b_i], 0, top_mentions_index[b_i])
+                    for b_i in range(b)
+                ]
             )
             assert selected_coref_labels.shape == (b, m, p + 1)
 
+            # NOTE: ideally, we should convert selected_coref_labels
+            # to a dense tensor _after_ the selection, with a tensor
+            # of shape (b, m, a). However, since we can't flatten a
+            # sparse tensor, we did not find a way to write the
+            # selection below using a sparse tensor.
+            selected_coref_labels = selected_coref_labels.to_dense()
             selected_coref_labels = batch_index_select(
                 selected_coref_labels.flatten(start_dim=0, end_dim=1),
                 1,
                 top_antecedents_index,
             ).reshape(b, m, a)
+            assert selected_coref_labels.shape == (b, m, a)
 
             # mentions with no antecedents are assumed to have the dummy antecedent
             dummy_labels = (1 - selected_coref_labels).prod(-1, keepdim=True)
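
The per-batch loop in this last hunk exists because, as the commit's NOTE says, torch.gather cannot run on sparse tensors while torch.index_select can. A standalone sketch of the same selection pattern (b, p, m and the index values are toy numbers; the identity-matrix labels are purely illustrative):

    import torch

    b, p, m = 2, 4, 2
    # (b, p, p + 1) sparse labels, here built from rows of an identity matrix
    coref_labels = torch.stack(
        [torch.eye(p + 1)[:p].to_sparse_coo() for _ in range(b)]
    )
    top_mentions_index = torch.tensor([[0, 2], [1, 3]])  # (b, m)

    # torch.gather(coref_labels, 1, ...) errors out on sparse input,
    # so rows are selected one batch element at a time
    selected = torch.stack(
        [
            torch.index_select(coref_labels[b_i], 0, top_mentions_index[b_i])
            for b_i in range(b)
        ]
    )
    assert selected.shape == (b, m, p + 1)

    # flatten() is likewise unsupported on sparse tensors, hence the
    # .to_dense() before the second selection in the commit
    selected = selected.to_dense()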
