Skip to content

Commit 148e564

Browse files
committed
inhibit torch.index_reduce_ warning
1 parent 97329fb commit 148e564

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

tibert/bertcoref.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
TypeVar,
1111
Union,
1212
)
13-
import re, glob, os
13+
import re, glob, os, warnings
1414
from collections import defaultdict
1515
from pathlib import Path
1616
from dataclasses import dataclass
@@ -332,9 +332,9 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]:
332332
# same length yet.
333333
return_tensors=None,
334334
)
335-
self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = (
336-
warning_state
337-
)
335+
self.tokenizer.deprecation_warnings[
336+
"Asking-to-pad-a-fast-tokenizer"
337+
] = warning_state
338338

339339
# keep encoding info
340340
batch._encodings = [f.encodings[0] for f in features]
@@ -1349,9 +1349,11 @@ def wordreduce_embeddings(
13491349
batch_word_ids = batch_word_ids[token_mask]
13501350
words_nb = len(set(batch_word_ids.tolist()))
13511351
words = torch.zeros(words_nb, h, device=device)
1352-
words.index_reduce_(
1353-
0, batch_word_ids, batch_encoded, "mean", include_self=False
1354-
)
1352+
with warnings.catch_warnings():
1353+
warnings.simplefilter("ignore")
1354+
words.index_reduce_(
1355+
0, batch_word_ids, batch_encoded, "mean", include_self=False
1356+
)
13551357
word_encoded.append(words)
13561358

13571359
# each 2D-tensor in word_encoded is of shape (w_i, h). w_i

0 commit comments

Comments
 (0)