|
7 | 7 | @pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
|
8 | 8 | @pytest.mark.parametrize(
|
9 | 9 | "subset",
|
10 |
| - [COMMON_MODELS] |
11 |
| - if IS_FAST_TEST |
12 |
| - else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"], |
| 10 | + ( |
| 11 | + [COMMON_MODELS] |
| 12 | + if IS_FAST_TEST |
| 13 | + else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"] |
| 14 | + ), |
13 | 15 | )
|
14 | 16 | @pytest.mark.parametrize("default_device", ["cpu", "cuda"])
|
15 |
| -def test_torchvision_models_lazy_init(subset, default_device): |
| 17 | +def test_models_lazy_init(subset, default_device): |
16 | 18 | sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True)
|
17 | 19 | for name, entry in sub_model_zoo.items():
|
18 | 20 | # TODO(ver217): lazy init does not support weight norm, skip these models
|
19 | 21 | if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
|
20 |
| - ("transformers_vit", "transformers_blip2") |
| 22 | + ("transformers_vit", "transformers_blip2", "transformers_whisper") |
21 | 23 | ):
|
22 | 24 | continue
|
23 | 25 | check_lazy_init(entry, verbose=True, default_device=default_device)
|
24 | 26 |
|
25 | 27 |
|
26 | 28 | if __name__ == "__main__":
|
27 |
| - test_torchvision_models_lazy_init("transformers", "cpu") |
| 29 | + test_models_lazy_init("transformers", "cpu") |
0 commit comments