|
26 | 26 | from pkg_resources import DistributionNotFound |
27 | 27 |
|
28 | 28 |
|
| 29 | +def _package_available(package_name: str) -> bool: |
| 30 | + """Check if a package is available in your environment. |
| 31 | +
|
| 32 | + >>> _package_available('os') |
| 33 | + True |
| 34 | + >>> _package_available('bla') |
| 35 | + False |
| 36 | + """ |
| 37 | + try: |
| 38 | + return find_spec(package_name) is not None |
| 39 | + except ModuleNotFoundError: |
| 40 | + return False |
| 41 | + |
| 42 | + |
29 | 43 | def _module_available(module_path: str) -> bool: |
30 | | - """Check if a path is available in your environment. |
| 44 | + """Check if a module path is available in your environment. |
31 | 45 |
|
32 | 46 | >>> _module_available('os') |
33 | 47 | True |
| 48 | + >>> _module_available('os.bla') |
| 49 | + False |
34 | 50 | >>> _module_available('bla.bla') |
35 | 51 | False |
36 | 52 | """ |
| 53 | + module_names = module_path.split(".") |
| 54 | + if not _package_available(module_names[0]): |
| 55 | + return False |
37 | 56 | try: |
38 | | - return find_spec(module_path) is not None |
| 57 | + module = importlib.import_module(module_names[0]) |
39 | 58 | except AttributeError: |
40 | 59 | # Python 3.6 |
41 | 60 | return False |
42 | | - except ModuleNotFoundError: |
43 | | - # Python 3.7+ |
| 61 | + except ImportError: |
44 | 62 | return False |
| 63 | + for name in module_names[1:]: |
| 64 | + if not hasattr(module, name): |
| 65 | + return False |
| 66 | + module = getattr(module, name) |
| 67 | + return True |
45 | 68 |
|
46 | 69 |
|
47 | 70 | def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool: |
@@ -78,25 +101,25 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: |
78 | 101 | # _TORCH_GREATER_EQUAL_DEV_1_11 = _compare_version("torch", operator.ge, "1.11.0", use_base_version=True) |
79 | 102 |
|
80 | 103 | _APEX_AVAILABLE = _module_available("apex.amp") |
81 | | -_DEEPSPEED_AVAILABLE = _module_available("deepspeed") |
| 104 | +_DEEPSPEED_AVAILABLE = _package_available("deepspeed") |
82 | 105 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") |
83 | 106 | _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3") |
84 | 107 | _FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4") |
85 | 108 | _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.group") |
86 | 109 | _HOROVOD_AVAILABLE = _module_available("horovod.torch") |
87 | | -_HYDRA_AVAILABLE = _module_available("hydra") |
| 110 | +_HYDRA_AVAILABLE = _package_available("hydra") |
88 | 111 | _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental") |
89 | | -_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") |
| 112 | +_JSONARGPARSE_AVAILABLE = _package_available("jsonargparse") |
90 | 113 | _KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available() |
91 | | -_NEPTUNE_AVAILABLE = _module_available("neptune") |
| 114 | +_NEPTUNE_AVAILABLE = _package_available("neptune") |
92 | 115 | _NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0") |
93 | | -_OMEGACONF_AVAILABLE = _module_available("omegaconf") |
94 | | -_POPTORCH_AVAILABLE = _module_available("poptorch") |
95 | | -_RICH_AVAILABLE = _module_available("rich") and _compare_version("rich", operator.ge, "10.2.2") |
| 116 | +_OMEGACONF_AVAILABLE = _package_available("omegaconf") |
| 117 | +_POPTORCH_AVAILABLE = _package_available("poptorch") |
| 118 | +_RICH_AVAILABLE = _package_available("rich") and _compare_version("rich", operator.ge, "10.2.2") |
96 | 119 | _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"]) |
97 | | -_TORCHTEXT_AVAILABLE = _module_available("torchtext") |
98 | | -_TORCHVISION_AVAILABLE = _module_available("torchvision") |
99 | | -_XLA_AVAILABLE: bool = _module_available("torch_xla") |
| 120 | +_TORCHTEXT_AVAILABLE = _package_available("torchtext") |
| 121 | +_TORCHVISION_AVAILABLE = _package_available("torchvision") |
| 122 | +_XLA_AVAILABLE: bool = _package_available("torch_xla") |
100 | 123 |
|
101 | 124 | from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402 |
102 | 125 |
|
|
0 commit comments