Skip to content

Commit 579d957

Browse files
committed
add test tokenizer
1 parent 30d29c2 commit 579d957

File tree

2 files changed

+441
-14
lines changed

2 files changed

+441
-14
lines changed

tests/testing_utils.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,30 @@
2020
import os
2121
import subprocess
2222
import sys
23+
import threading
2324
import unittest
25+
from collections import defaultdict
2426
from collections.abc import Mapping
2527
from contextlib import contextmanager
28+
from unittest.mock import patch
2629

2730
import numpy as np
2831
import paddle
2932
import paddle.distributed.fleet as fleet
33+
import urllib3
3034
import yaml
3135

3236
from paddlenlp.trainer.argparser import strtobool
33-
from paddlenlp.utils.import_utils import is_package_available, is_paddle_available
37+
from paddlenlp.utils.import_utils import (
38+
is_package_available,
39+
is_paddle_available,
40+
is_tokenizers_available,
41+
)
42+
43+
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
44+
DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
45+
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
46+
# The identifiers above are used to test Auto{Config, Model, Tokenizer} model_type detection.
3447

3548
__all__ = ["get_vocab_list", "stable_softmax", "cross_entropy"]
3649

@@ -539,3 +552,65 @@ def init_dist_env(self, config: dict = {}):
539552

540553
fleet.init(is_collective=True, strategy=strategy)
541554
fleet.get_hybrid_communicate_group()
555+
556+
557+
def require_tokenizers(test_case):
    """
    Mark a test as depending on the 🤗 Tokenizers package.

    The decorated test is skipped (rather than failed) whenever
    `is_tokenizers_available()` reports that the package is missing.
    """
    skip_unless_available = unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")
    return skip_unless_available(test_case)
562+
563+
564+
class RequestCounter:
    """
    Context manager that counts the HTTP requests issued while it is active.

    It works by patching `urllib3.connectionpool.log.debug` and inspecting the
    messages logged there, so it might not be robust if urllib3 changes its
    logging format, but it should be good enough for tests. Only requests
    logged from the thread that entered the context are counted.

    Usage:
    ```py
    with RequestCounter() as counter:
        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
    assert counter["GET"] == 0
    assert counter["HEAD"] == 1
    assert counter.total_calls == 1
    ```
    """

    # HTTP verbs looked for in each logged message, checked in this order.
    _HTTP_METHODS = ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH")

    def __enter__(self):
        """Install the logging patch and reset all counters; returns `self`."""
        self._counter = defaultdict(int)
        self._thread_id = threading.get_ident()
        # Thread id of every intercepted `log.debug` call, in call order.
        self._extra_info = []

        def patched_with_thread_info(func):
            def wrap(*args, **kwargs):
                # Remember which thread logged, then defer to the real logger.
                self._extra_info.append(threading.get_ident())
                return func(*args, **kwargs)

            return wrap

        self.patcher = patch.object(
            urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug)
        )
        self.mock = self.patcher.start()
        return self

    def __exit__(self, *args, **kwargs) -> None:
        """
        Tally the intercepted log messages per HTTP method and remove the patch.

        Raises:
            AssertionError: if the number of recorded calls and the number of
                recorded thread ids disagree.
        """
        try:
            assert len(self.mock.call_args_list) == len(
                self._extra_info
            ), "mismatch between intercepted log calls and recorded thread ids"

            for thread_id, call in zip(self._extra_info, self.mock.call_args_list):
                # Ignore anything logged by other threads.
                if thread_id != self._thread_id:
                    continue
                message, *fmt_args = call.args
                # Like `logging`, only interpolate when arguments were supplied, so a
                # bare message containing a literal "%" does not blow up.
                log = message % tuple(fmt_args) if fmt_args else message
                for method in self._HTTP_METHODS:
                    if method in log:
                        self._counter[method] += 1
                        break
        finally:
            # Always undo the patch, even if the bookkeeping above failed;
            # otherwise urllib3's logger would stay patched for later tests.
            self.patcher.stop()

    def __getitem__(self, key: str) -> int:
        """Return how many requests with HTTP method `key` were counted."""
        return self._counter[key]

    @property
    def total_calls(self) -> int:
        """Total number of counted requests across all HTTP methods."""
        return sum(self._counter.values())

0 commit comments

Comments
 (0)