@@ -76,21 +76,21 @@ class CoreferenceDocument:
     def __len__(self) -> int:
         return len(self.tokens)
 
-    def coref_labels(self, max_span_size: int) -> List[List[int]]:
+    def coref_labels(self, max_span_size: int) -> torch.Tensor:
         """
-        :return: a list of shape ``(spans_nb, spans_nb + 1)``.
-            when ``out[i][j] == 1``, span j is the preceding
-            coreferent mention if span i. when ``j == spans_nb``,
-            i has no preceding coreferent mention.
+        :return: a sparse COO tensor of shape ``(spans_nb, spans_nb +
+            1)``. when ``out[i][j] == 1``, span j is the
+            preceding coreferent mention of span i. when ``j ==
+            spans_nb``, i has no preceding coreferent mention.
         """
         spans_idx = {
             indices: i
             for i, indices in enumerate(spans_indexs(self.tokens, max_span_size))
         }
         spans_nb = len(spans_idx)
 
-        # labels = [[0] * (spans_nb + 1) for _ in range(spans_nb)]
-        labels = np.zeros((spans_nb, spans_nb + 1))
+        label_indices = []
+        label_values = []
 
         # spans in a coref chain : mark all antecedents
         for chain in self.coref_chains:
@@ -110,24 +110,35 @@ def coref_labels(self, max_span_size: int) -> List[List[int]]:
                 other_mention_idx = spans_idx[
                     (other_mention.start_idx, other_mention.end_idx)
                 ]
-                labels[mention_idx][other_mention_idx] = 1
+                label_indices.append([mention_idx, other_mention_idx])
+                label_values.append(1)
+
+        if len(label_indices) == 0:
+            labels_t = torch.sparse_coo_tensor(size=(spans_nb, spans_nb))  # type: ignore
+        else:
+            labels_t = torch.sparse_coo_tensor(
+                torch.tensor(label_indices).t(), label_values, (spans_nb, spans_nb)
+            )
 
         # spans without preceding mentions : mark preceding mention to
         # be the null span
-        for i in range(len(labels)):
-            if labels[i].sum() == 0:
-                labels[i][spans_nb] = 1
+        null_t = torch.zeros(spans_nb, 1)
+        for i in range(spans_nb):
+            if labels_t[i].sum() == 0:
+                null_t[i][0] = 1
+        labels_t = torch.cat([labels_t, null_t.to_sparse_coo()], dim=1)
+        assert labels_t.shape == (spans_nb, spans_nb + 1)
 
-        return labels.tolist()
+        return labels_t
 
-    def mention_labels(self, max_span_size: int) -> List[int]:
+    def mention_labels(self, max_span_size: int) -> torch.Tensor:
         """
         :return: a list of shape ``(spans_nb)``
         """
         spans_idx = spans_indexs(self.tokens, max_span_size)
         spans_nb = len(spans_idx)
 
-        labels = [0 for _ in range(spans_nb)]
+        labels = torch.zeros(spans_nb)
 
         for chain in self.coref_chains:
             for mention in chain:
@@ -142,7 +153,7 @@ def mention_labels(self, max_span_size: int) -> List[int]:
 
         return labels
 
-    def document_labels(self, max_span_size: int) -> Tuple[List[List[int]], List[int]]:
+    def document_labels(self, max_span_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return (self.coref_labels(max_span_size), self.mention_labels(max_span_size))
 
     def prepared_document(
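As a side note for review: the labelling scheme built by the new `coref_labels` can be reproduced in isolation. Below is a minimal sketch, with an invented span count and a single invented antecedent pair; it assumes a PyTorch version that provides `Tensor.to_sparse_coo` (1.12 or later).

```python
import torch

# toy setup (hypothetical): 4 candidate spans, where span 2 has
# span 0 as its preceding coreferent mention
spans_nb = 4
label_indices = [[2, 0]]  # (mention, antecedent) pairs
label_values = [1]

# sparse_coo_tensor expects indices as a (ndim, nnz) tensor, hence .t()
labels_t = torch.sparse_coo_tensor(
    torch.tensor(label_indices).t(), label_values, (spans_nb, spans_nb)
)

# rows with no antecedent point to the extra "null span" column
null_t = torch.zeros(spans_nb, 1)
for i in range(spans_nb):
    if labels_t[i].sum() == 0:
        null_t[i][0] = 1

# torch.cat keeps the result sparse when every input is sparse,
# so the dense null column is converted first
labels_t = torch.cat([labels_t, null_t.to_sparse_coo()], dim=1)
assert labels_t.shape == (spans_nb, spans_nb + 1)
print(labels_t.to_dense())
```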
@@ -272,14 +283,14 @@ def from_wpieced_to_tokenized(
     @staticmethod
     def from_labels(
         tokens: List[str],
-        coref_labels: List[List[int]],
-        mention_labels: List[int],
+        coref_labels: torch.Tensor,
+        mention_labels: torch.Tensor,
         max_span_size: int,
     ) -> CoreferenceDocument:
         """Construct a CoreferenceDocument using labels
 
         :param tokens:
-        :param coref_labels: ``(spans_nb, spans_nb + 1)``
+        :param coref_labels: sparse tensor of shape ``(spans_nb, spans_nb + 1)``
         :param mention_labels: ``(spans_nb)``
         :param max_span_size:
         """
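Since `from_labels` now receives a sparse tensor, it may help to see how the nonzero (mention, antecedent) pairs can be read back without densifying. A sketch on a made-up tensor (not the repository's API):

```python
import torch

spans_nb = 3
# hypothetical labels: span 1 has span 0 as antecedent; spans 0 and 2
# point to the null column (index spans_nb)
coref_labels = torch.sparse_coo_tensor(
    torch.tensor([[0, 1, 2], [3, 0, 3]]),  # row indices, column indices
    [1, 1, 1],
    (spans_nb, spans_nb + 1),
)

# coalesce() sorts and deduplicates entries; indices() then returns
# a (2, nnz) tensor of (row, column) coordinates
rows, cols = coref_labels.coalesce().indices()
for i, j in zip(rows.tolist(), cols.tolist()):
    if j == spans_nb:
        print(f"span {i}: no preceding coreferent mention")
    else:
        print(f"span {i}: preceding coreferent mention is span {j}")
```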
@@ -403,18 +414,25 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]:
 
         for document, tokens in zip(documents, batch["input_ids"]):  # type: ignore
             document.tokens = tokens
-
         labels = [doc.document_labels(self.max_span_size) for doc in documents]
-        batch["coref_labels"] = [coref_labels for coref_labels, _ in labels]
-        batch["mention_labels"] = [mention_labels for _, mention_labels in labels]
 
-        return BatchEncoding(
+        del batch["coref_labels"]
+        del batch["mention_labels"]
+        batch = BatchEncoding(
             {
                 k: torch.tensor(v, dtype=torch.int64, device=torch.device(self.device))
                 for k, v in batch.items()
             },
             encoding=batch.encodings,
         )
+        batch["coref_labels"] = torch.stack(
+            [coref_labels for coref_labels, _ in labels]
+        )
+        batch["mention_labels"] = torch.stack(
+            [mention_labels for _, mention_labels in labels]
+        )
+
+        return batch
 
 
 class CoreferenceDataset(Dataset):
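The collator now stacks the per-document label tensors instead of routing them through the `torch.tensor(...)` dict comprehension, which would fail on sparse inputs. The change relies on `torch.stack` accepting sparse COO tensors; a standalone sketch of that assumption on toy shapes:

```python
import torch

spans_nb = 2
# two hypothetical documents' coref labels, one nonzero entry each
docs_coref_labels = [
    torch.sparse_coo_tensor(torch.tensor([[0], [2]]), [1], (spans_nb, spans_nb + 1)),
    torch.sparse_coo_tensor(torch.tensor([[1], [0]]), [1], (spans_nb, spans_nb + 1)),
]

# stacking adds a leading batch dimension while staying sparse
batched = torch.stack(docs_coref_labels)
assert batched.shape == (len(docs_coref_labels), spans_nb, spans_nb + 1)
assert batched.is_sparse
```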
@@ -1346,7 +1364,7 @@ def forward(
         :param attention_mask: a tensor of shape ``(b, s)``
         :param token_type_ids: a tensor of shape ``(b, s)``
         :param position_ids: a tensor of shape ``(b, s)``
-        :param coref_labels: a tensor of shape ``(b, p, p)``
+        :param coref_labels: a sparse tensor of shape ``(b, p, p)``
         :param mention_labels: a tensor of shape ``(b, p)``
         :param return_hidden_state: if ``True``, set the hidden_state of
             ``BertCoreferenceResolutionOutput``
@@ -1478,17 +1496,32 @@ def forward(
         # -- loss computation --
         loss = None
         if coref_labels is not None and mention_labels is not None:
+
             # -- coref loss
-            selected_coref_labels = batch_index_select(
-                coref_labels, 1, top_mentions_index
+
+            # NOTE: we have to rely on such a loop, as torch.gather
+            # cannot be used on sparse tensors, which prevents using
+            # batch_index_select
+            selected_coref_labels = torch.stack(
+                [
+                    torch.index_select(coref_labels[b_i], 0, top_mentions_index[b_i])
+                    for b_i in range(b)
+                ]
             )
             assert selected_coref_labels.shape == (b, m, p + 1)
 
+            # NOTE: ideally, we should convert selected_coref_labels
+            # to a dense tensor _after_ the selection, with a tensor
+            # of shape (b, m, a). However, since we can't flatten a
+            # sparse tensor, we did not find a way to write the
+            # selection below using a sparse tensor.
+            selected_coref_labels = selected_coref_labels.to_dense()
             selected_coref_labels = batch_index_select(
                 selected_coref_labels.flatten(start_dim=0, end_dim=1),
                 1,
                 top_antecedents_index,
             ).reshape(b, m, a)
+            assert selected_coref_labels.shape == (b, m, a)
 
             # mentions with no antecedents are assumed to have the dummy antecedent
             dummy_labels = (1 - selected_coref_labels).prod(-1, keepdim=True)
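The per-batch loop above exists because `torch.gather` has no sparse COO implementation, whereas `torch.index_select` does. A minimal sketch of the same selection pattern on invented tensors (the names `b`, `p`, `m` mirror the diff; the data is made up):

```python
import torch

b, p, m = 2, 4, 2  # batch size, spans per document, kept mentions

# sparse (b, p, p + 1) labels with a single nonzero entry at [0, 2, 0]
coref_labels = torch.sparse_coo_tensor(
    torch.tensor([[0], [2], [0]]), [1.0], (b, p, p + 1)
)
# hypothetical indices of the top-scoring mentions per batch element
top_mentions_index = torch.tensor([[1, 2], [0, 3]])

# torch.gather would raise on a sparse input, so select rows per
# batch element with index_select and stack the results back
selected = torch.stack(
    [
        torch.index_select(coref_labels[b_i], 0, top_mentions_index[b_i])
        for b_i in range(b)
    ]
)
assert selected.shape == (b, m, p + 1)
```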