Skip to content

Commit d411156

Browse files
committed
update mentions indexing
1 parent d893317 commit d411156

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

tibert/bertcoref.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def prepared_document(
158158
except ValueError:
159159
continue
160160
new_chain.append(
161-
Mention(tokens[start_idx : end_idx + 1], start_idx, end_idx)
161+
Mention(tokens[start_idx : end_idx + 1], start_idx, end_idx + 1)
162162
)
163163
if len(new_chain) > 0:
164164
new_chains.append(new_chain)
@@ -193,7 +193,7 @@ def from_wpieced_to_tokenized(
193193
new_end_idx = wp_to_token[mention.end_idx]
194194
new_chain.append(
195195
Mention(
196-
tokens[new_start_idx : new_end_idx + 1],
196+
tokens[new_start_idx:new_end_idx],
197197
new_start_idx,
198198
new_end_idx,
199199
)
@@ -227,7 +227,7 @@ def from_labels(
227227
# singleton cluster
228228
if mention_labels[i] == 1:
229229
start_idx, end_idx = spans_idx[i]
230-
mention_tokens = tokens[start_idx : end_idx + 1]
230+
mention_tokens = tokens[start_idx:end_idx]
231231
chains.append([Mention(mention_tokens, start_idx, end_idx)])
232232
already_visited_mentions.append(i)
233233

@@ -237,7 +237,7 @@ def from_labels(
237237
continue
238238

239239
start_idx, end_idx = spans_idx[i]
240-
mention_tokens = tokens[start_idx : end_idx + 1]
240+
mention_tokens = tokens[start_idx:end_idx]
241241
chain = [Mention(mention_tokens, start_idx, end_idx)]
242242

243243
for j, label in enumerate(mlabels):
@@ -246,7 +246,7 @@ def from_labels(
246246
continue
247247

248248
start_idx, end_idx = spans_idx[j]
249-
mention_tokens = tokens[start_idx : end_idx + 1]
249+
mention_tokens = tokens[start_idx:end_idx]
250250
chain.append(Mention(mention_tokens, start_idx, end_idx))
251251
already_visited_mentions.append(j)
252252

@@ -426,9 +426,9 @@ def from_conll2012_file(
426426

427427
if mention_is_ending:
428428
mention_start_idx = open_mentions[chain_id].pop()
429-
mention_end_idx = len(document_tokens) - 1
429+
mention_end_idx = len(document_tokens)
430430
mention = Mention(
431-
document_tokens[mention_start_idx : mention_end_idx + 1],
431+
document_tokens[mention_start_idx:mention_end_idx],
432432
mention_start_idx,
433433
mention_end_idx,
434434
)
@@ -665,7 +665,7 @@ def coreference_documents(
665665
# the antecedent is the dummy mention : maybe we have
666666
# a one-mention chain ?
667667
if top_antecedent_idx == antecedents_nb - 1:
668-
if self.top_mentions_scores[b_i][m_j].item() > 0.0:
668+
if float(self.top_mentions_scores[b_i][m_j].item()) > 0.0:
669669
G.add_node(span_mention)
670670
continue
671671

@@ -907,7 +907,7 @@ def distance_between_spans(self, spans_nb: int, seq_size: int) -> torch.Tensor:
907907

908908
# distance between a span and its antecedent is defined to be
909909
# the span start index minus the antecedent span end index
910-
dist = start_end_idx_combinations[:, 0] - start_end_idx_combinations[:, 1]
910+
dist = start_end_idx_combinations[:, 0] - start_end_idx_combinations[:, 1] + 1
911911
assert dist.shape == (p * p,)
912912
dist = dist.reshape(spans_nb, spans_nb)
913913

tibert/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def spans_indexs(seq: List, max_len: int) -> List[Tuple[int, int]]:
3737
for i in range(1, min(len(seq), max_len + 1)):
3838
for span in windowed(range(len(seq)), i):
3939
span = cast(Tuple[int, ...], span)
40-
indexs.append((min(span), max(span)))
40+
indexs.append((min(span), max(span) + 1))
4141
return indexs
4242

4343

@@ -124,7 +124,7 @@ def pprint_coreference_document(doc: CoreferenceDocument):
124124
related_chains = [
125125
(chain_i, start_i, end_i)
126126
for chain_i, start_i, end_i in mentions
127-
if start_i == token_i or end_i == token_i
127+
if start_i == token_i or end_i - 1 == token_i
128128
]
129129

130130
for chain_i, start_i, _ in related_chains:
@@ -134,7 +134,7 @@ def pprint_coreference_document(doc: CoreferenceDocument):
134134
out.append(token)
135135

136136
for chain_i, _, end_i in related_chains:
137-
if token_i == end_i:
137+
if token_i == end_i - 1:
138138
out.append(f")[/red]")
139139

140140
try:

0 commit comments

Comments (0)