|
26 | 26 | from pkg_resources import DistributionNotFound
|
27 | 27 |
|
28 | 28 |
|
| 29 | +def _package_available(package_name: str) -> bool: |
| 30 | + """Check if a package is available in your environment. |
| 31 | +
|
| 32 | + >>> _package_available('os') |
| 33 | + True |
| 34 | + >>> _package_available('bla') |
| 35 | + False |
| 36 | + """ |
| 37 | + try: |
| 38 | + return find_spec(package_name) is not None |
| 39 | + except ModuleNotFoundError: |
| 40 | + return False |
| 41 | + |
| 42 | + |
29 | 43 | def _module_available(module_path: str) -> bool:
|
30 |
| - """Check if a path is available in your environment. |
| 44 | + """Check if a module path is available in your environment. |
31 | 45 |
|
32 | 46 | >>> _module_available('os')
|
33 | 47 | True
|
| 48 | + >>> _module_available('os.bla') |
| 49 | + False |
34 | 50 | >>> _module_available('bla.bla')
|
35 | 51 | False
|
36 | 52 | """
|
| 53 | + module_names = module_path.split(".") |
| 54 | + if not _package_available(module_names[0]): |
| 55 | + return False |
37 | 56 | try:
|
38 |
| - return find_spec(module_path) is not None |
| 57 | + module = importlib.import_module(module_names[0]) |
39 | 58 | except AttributeError:
|
40 | 59 | # Python 3.6
|
41 | 60 | return False
|
42 |
| - except ModuleNotFoundError: |
43 |
| - # Python 3.7+ |
| 61 | + except ImportError: |
44 | 62 | return False
|
| 63 | + for name in module_names[1:]: |
| 64 | + if not hasattr(module, name): |
| 65 | + return False |
| 66 | + module = getattr(module, name) |
| 67 | + return True |
45 | 68 |
|
46 | 69 |
|
47 | 70 | def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
|
@@ -78,25 +101,25 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
|
78 | 101 | # _TORCH_GREATER_EQUAL_DEV_1_11 = _compare_version("torch", operator.ge, "1.11.0", use_base_version=True)
|
79 | 102 |
|
80 | 103 | _APEX_AVAILABLE = _module_available("apex.amp")
|
81 |
| -_DEEPSPEED_AVAILABLE = _module_available("deepspeed") |
| 104 | +_DEEPSPEED_AVAILABLE = _package_available("deepspeed") |
82 | 105 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn")
|
83 | 106 | _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
|
84 | 107 | _FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4")
|
85 | 108 | _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.group")
|
86 | 109 | _HOROVOD_AVAILABLE = _module_available("horovod.torch")
|
87 |
| -_HYDRA_AVAILABLE = _module_available("hydra") |
| 110 | +_HYDRA_AVAILABLE = _package_available("hydra") |
88 | 111 | _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
|
89 |
| -_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") |
| 112 | +_JSONARGPARSE_AVAILABLE = _package_available("jsonargparse") |
90 | 113 | _KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available()
|
91 |
| -_NEPTUNE_AVAILABLE = _module_available("neptune") |
| 114 | +_NEPTUNE_AVAILABLE = _package_available("neptune") |
92 | 115 | _NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0")
|
93 |
| -_OMEGACONF_AVAILABLE = _module_available("omegaconf") |
94 |
| -_POPTORCH_AVAILABLE = _module_available("poptorch") |
95 |
| -_RICH_AVAILABLE = _module_available("rich") and _compare_version("rich", operator.ge, "10.2.2") |
| 116 | +_OMEGACONF_AVAILABLE = _package_available("omegaconf") |
| 117 | +_POPTORCH_AVAILABLE = _package_available("poptorch") |
| 118 | +_RICH_AVAILABLE = _package_available("rich") and _compare_version("rich", operator.ge, "10.2.2") |
96 | 119 | _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
|
97 |
| -_TORCHTEXT_AVAILABLE = _module_available("torchtext") |
98 |
| -_TORCHVISION_AVAILABLE = _module_available("torchvision") |
99 |
| -_XLA_AVAILABLE: bool = _module_available("torch_xla") |
| 120 | +_TORCHTEXT_AVAILABLE = _package_available("torchtext") |
| 121 | +_TORCHVISION_AVAILABLE = _package_available("torchvision") |
| 122 | +_XLA_AVAILABLE: bool = _package_available("torch_xla") |
100 | 123 |
|
101 | 124 | from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402
|
102 | 125 |
|
|
0 commit comments