|
20 | 20 | import os |
21 | 21 | import subprocess |
22 | 22 | import sys |
| 23 | +import threading |
23 | 24 | import unittest |
| 25 | +from collections import defaultdict |
24 | 26 | from collections.abc import Mapping |
25 | 27 | from contextlib import contextmanager |
| 28 | +from unittest.mock import patch |
26 | 29 |
|
27 | 30 | import numpy as np |
28 | 31 | import paddle |
29 | 32 | import paddle.distributed.fleet as fleet |
| 33 | +import urllib3 |
30 | 34 | import yaml |
31 | 35 |
|
32 | 36 | from paddlenlp.trainer.argparser import strtobool |
33 | | -from paddlenlp.utils.import_utils import is_package_available, is_paddle_available |
| 37 | +from paddlenlp.utils.import_utils import ( |
| 38 | + is_package_available, |
| 39 | + is_paddle_available, |
| 40 | + is_tokenizers_available, |
| 41 | +) |
| 42 | + |
| 43 | +SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" |
| 44 | +DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown" |
| 45 | +DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer" |
| 46 | +# Used to test Auto{Config, Model, Tokenizer} model_type detection. |
34 | 47 |
|
35 | 48 | __all__ = ["get_vocab_list", "stable_softmax", "cross_entropy"] |
36 | 49 |
|
@@ -539,3 +552,65 @@ def init_dist_env(self, config: dict = {}): |
539 | 552 |
|
540 | 553 | fleet.init(is_collective=True, strategy=strategy) |
541 | 554 | fleet.get_hybrid_communicate_group() |
| 555 | + |
| 556 | + |
| 557 | +def require_tokenizers(test_case): |
| 558 | + """ |
| 559 | + Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed. |
| 560 | + """ |
| 561 | + return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case) |
| 562 | + |
| 563 | + |
| 564 | +class RequestCounter: |
| 565 | + """ |
| 566 | + Helper class that will count all requests made online. |
| 567 | +
|
| 568 | + Might not be robust if urllib3 changes its logging format but should be good enough for us. |
| 569 | +
|
| 570 | + Usage: |
| 571 | + ```py |
| 572 | + with RequestCounter() as counter: |
| 573 | + _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") |
| 574 | + assert counter["GET"] == 0 |
| 575 | + assert counter["HEAD"] == 1 |
| 576 | + assert counter.total_calls == 1 |
| 577 | + ``` |
| 578 | + """ |
| 579 | + |
| 580 | + def __enter__(self): |
| 581 | + self._counter = defaultdict(int) |
| 582 | + self._thread_id = threading.get_ident() |
| 583 | + self._extra_info = [] |
| 584 | + |
| 585 | + def patched_with_thread_info(func): |
| 586 | + def wrap(*args, **kwargs): |
| 587 | + self._extra_info.append(threading.get_ident()) |
| 588 | + return func(*args, **kwargs) |
| 589 | + |
| 590 | + return wrap |
| 591 | + |
| 592 | + self.patcher = patch.object( |
| 593 | + urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug) |
| 594 | + ) |
| 595 | + self.mock = self.patcher.start() |
| 596 | + return self |
| 597 | + |
| 598 | + def __exit__(self, *args, **kwargs) -> None: |
| 599 | + assert len(self.mock.call_args_list) == len(self._extra_info) |
| 600 | + |
| 601 | + for thread_id, call in zip(self._extra_info, self.mock.call_args_list): |
| 602 | + if thread_id != self._thread_id: |
| 603 | + continue |
| 604 | + log = call.args[0] % call.args[1:] |
| 605 | + for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"): |
| 606 | + if method in log: |
| 607 | + self._counter[method] += 1 |
| 608 | + break |
| 609 | + self.patcher.stop() |
| 610 | + |
| 611 | + def __getitem__(self, key: str) -> int: |
| 612 | + return self._counter[key] |
| 613 | + |
| 614 | + @property |
| 615 | + def total_calls(self) -> int: |
| 616 | + return sum(self._counter.values()) |
0 commit comments