|
17 | 17 | import urllib |
18 | 18 | import warnings |
19 | 19 | import zipfile |
| 20 | +from functools import wraps |
20 | 21 | from typing import List, Union |
21 | 22 |
|
22 | 23 | import numpy as np |
@@ -121,3 +122,40 @@ def get_loader_options(dataset: "cebra.data.Dataset") -> List[str]: |
121 | 122 | "The 'get_loader_options' function has been moved to 'cebra.data.helpers' module. " |
122 | 123 | "Please update your imports.", DeprecationWarning) |
123 | 124 | return cebra.data.helper.get_loader_options |
| 125 | + |
| 126 | + |
| 127 | +def requires_package_version(module, version: str): |
| 128 | + """Decorator to require a minimum version of a package. |
| 129 | +
|
| 130 | + Args: |
| 131 | + module: Module to be checked. |
| 132 | + version: The minimum required version for the module. |
| 133 | +
|
| 134 | + Raises: |
| 135 | + ImportError: If the specified ``module`` version is less than |
| 136 | + the required ``version``. |
| 137 | + """ |
| 138 | + |
| 139 | + required_version = pkg_resources.parse_version(version) |
| 140 | + |
| 141 | + def _requires_package_version(function): |
| 142 | + |
| 143 | + @wraps(function) |
| 144 | + def wrapper(*args, patched_version=None, **kwargs): |
| 145 | + if patched_version != None: |
| 146 | + installed_version = pkg_resources.parse_version( |
| 147 | + patched_version) # Use the patched version if provided |
| 148 | + else: |
| 149 | + installed_version = pkg_resources.parse_version( |
| 150 | + module.__version__) |
| 151 | + |
| 152 | + if installed_version < required_version: |
| 153 | + raise ImportError( |
| 154 | + f"The function '{function.__name__}' requires {module.__name__} " |
| 155 | + f"version {required_version} or higher, but you have {installed_version}. " |
| 156 | + f"Please upgrade {module.__name__}.") |
| 157 | + return function(*args, **kwargs) |
| 158 | + |
| 159 | + return wrapper |
| 160 | + |
| 161 | + return _requires_package_version |
0 commit comments