10 | 10 | TypeVar, |
11 | 11 | Union, |
12 | 12 | ) |
13 | | -import re, glob, os |
| 13 | +import re, glob, os, warnings |
14 | 14 | from collections import defaultdict |
15 | 15 | from pathlib import Path |
16 | 16 | from dataclasses import dataclass |
@@ -332,9 +332,9 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]: |
332 | 332 | # same length yet. |
333 | 333 | return_tensors=None, |
334 | 334 | ) |
335 | | - self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = ( |
336 | | - warning_state |
337 | | - ) |
| 335 | + self.tokenizer.deprecation_warnings[ |
| 336 | + "Asking-to-pad-a-fast-tokenizer" |
| 337 | + ] = warning_state |
338 | 338 | |
339 | 339 | # keep encoding info |
340 | 340 | batch._encodings = [f.encodings[0] for f in features] |
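
Note on the hunk above: it only re-wraps the line that restores the tokenizer's internal "Asking-to-pad-a-fast-tokenizer" flag after the pad() call; the lines that save that flag and silence the warning before padding sit outside the shown context. The sketch below reconstructs the presumable save / suppress / restore pattern around tokenizer.pad() under that assumption; the checkpoint name and the features list are illustrative, not the project's actual inputs.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any fast tokenizer
features = [tokenizer(t) for t in ["a short text", "a slightly longer text here"]]

# Remember the current state of the internal flag, set it so that pad() does
# not log the "use __call__ instead of pad on a fast tokenizer" advice, then
# restore the previous state so later calls behave as before.
warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
batch = tokenizer.pad(
    features,
    padding=True,
    # no tensor conversion yet: fields added later may not all have the
    # same length (same reason as the comment in the diff above)
    return_tensors=None,
)
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
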
@@ -1349,9 +1349,11 @@ def wordreduce_embeddings( |
1349 | 1349 | batch_word_ids = batch_word_ids[token_mask] |
1350 | 1350 | words_nb = len(set(batch_word_ids.tolist())) |
1351 | 1351 | words = torch.zeros(words_nb, h, device=device) |
1352 | | - words.index_reduce_( |
1353 | | - 0, batch_word_ids, batch_encoded, "mean", include_self=False |
1354 | | - ) |
| 1352 | + with warnings.catch_warnings(): |
| 1353 | + warnings.simplefilter("ignore") |
| 1354 | + words.index_reduce_( |
| 1355 | + 0, batch_word_ids, batch_encoded, "mean", include_self=False |
| 1356 | + ) |
1355 | 1357 | word_encoded.append(words) |
1356 | 1358 | |
1357 | 1359 | # each 2D-tensor in word_encoded is of shape (w_i, h). w_i |
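
In this second hunk, torch.Tensor.index_reduce_ (which averages the sub-word token embeddings belonging to each word into a single vector) gets wrapped in warnings.catch_warnings() with simplefilter("ignore"): index_reduce_ is flagged as a beta feature by PyTorch and, presumably the motivation for this change, emits a UserWarning on each call; the context manager silences it for this one statement instead of filtering it process-wide. A minimal standalone sketch of the same pooling step, with illustrative shapes and names rather than the project's actual inputs:

import warnings
import torch

h = 4                                            # hidden size
batch_encoded = torch.randn(5, h)                # 5 sub-word token embeddings
batch_word_ids = torch.tensor([0, 0, 1, 2, 2])   # word index of each token
words_nb = len(set(batch_word_ids.tolist()))     # 3 distinct words

words = torch.zeros(words_nb, h)
with warnings.catch_warnings():
    # ignore the "index_reduce() is in beta" UserWarning locally
    warnings.simplefilter("ignore")
    # include_self=False: the zeros used to initialise `words` are not
    # counted in the mean, only the rows of batch_encoded are
    words.index_reduce_(0, batch_word_ids, batch_encoded, "mean", include_self=False)

# words[i] is now the mean of the token embeddings mapped to word i
print(words.shape)  # torch.Size([3, 4])
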
|