Skip to content

Commit a21ba0e

Browse files
authored
Add Matplotlib ImportError when using compare_models() (#53)
* raise ImportError * Update matplotlib.py - added note to explain this is min. * parsing version + add test * choose option 1 + fix test * add decorator and change test * Add docstring
1 parent 6fe5f19 commit a21ba0e

File tree

3 files changed

+69
-2
lines changed

3 files changed

+69
-2
lines changed

cebra/helper.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import urllib
1818
import warnings
1919
import zipfile
20+
from functools import wraps
2021
from typing import List, Union
2122

2223
import numpy as np
@@ -121,3 +122,40 @@ def get_loader_options(dataset: "cebra.data.Dataset") -> List[str]:
121122
"The 'get_loader_options' function has been moved to 'cebra.data.helpers' module. "
122123
"Please update your imports.", DeprecationWarning)
123124
return cebra.data.helper.get_loader_options
125+
126+
127+
def requires_package_version(module, version: str):
128+
"""Decorator to require a minimum version of a package.
129+
130+
Args:
131+
module: Module to be checked.
132+
version: The minimum required version for the module.
133+
134+
Raises:
135+
ImportError: If the specified ``module`` version is less than
136+
the required ``version``.
137+
"""
138+
139+
required_version = pkg_resources.parse_version(version)
140+
141+
def _requires_package_version(function):
142+
143+
@wraps(function)
144+
def wrapper(*args, patched_version=None, **kwargs):
145+
if patched_version != None:
146+
installed_version = pkg_resources.parse_version(
147+
patched_version) # Use the patched version if provided
148+
else:
149+
installed_version = pkg_resources.parse_version(
150+
module.__version__)
151+
152+
if installed_version < required_version:
153+
raise ImportError(
154+
f"The function '{function.__name__}' requires {module.__name__} "
155+
f"version {required_version} or higher, but you have {installed_version}. "
156+
f"Please upgrade {module.__name__}.")
157+
return function(*args, **kwargs)
158+
159+
return wrapper
160+
161+
return _requires_package_version

cebra/integrations/matplotlib.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,10 @@ def plot_consistency(
11321132
).plot(**kwargs)
11331133

11341134

1135+
from cebra.helper import requires_package_version
1136+
1137+
1138+
@requires_package_version(matplotlib, "3.6")
11351139
def compare_models(
11361140
models: List[CEBRA],
11371141
labels: Optional[List[str]] = None,
@@ -1185,8 +1189,10 @@ def compare_models(
11851189
The axis of the generated plot. If no ``ax`` argument was specified, it will be created
11861190
by the function and returned here.
11871191
"""
1192+
11881193
if not isinstance(models, list):
11891194
raise ValueError(f"Invalid list of models, got {type(models)}.")
1195+
11901196
for model in models:
11911197
if not isinstance(model, CEBRA):
11921198
raise ValueError(

tests/test_plot.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import matplotlib
1515
import matplotlib.pyplot as plt
1616
import numpy as np
17+
import pkg_resources
1718
import pytest
1819
import torch
1920
from sklearn.exceptions import NotFittedError
@@ -143,10 +144,32 @@ def test_plot_loss():
143144
plt.close()
144145

145146

147+
@pytest.mark.parametrize("matplotlib_version",
148+
["3.3", "3.4.2", "3.5", "3.6", "3.7"])
149+
def test_compare_models_with_different_versions(matplotlib_version):
150+
# example dataset
151+
X = np.random.uniform(0, 1, (1000, 2))
152+
n_models = 2
153+
154+
fitted_models = []
155+
for _ in range(n_models):
156+
fitted_models.append(
157+
cebra_sklearn_cebra.CEBRA(max_iterations=10, batch_size=128).fit(X))
158+
159+
# minimum version of matplotlib
160+
minimum_version = "3.6"
161+
162+
if pkg_resources.parse_version(
163+
matplotlib_version) < pkg_resources.parse_version(minimum_version):
164+
with pytest.raises(ImportError):
165+
cebra_plot.compare_models(models=fitted_models,
166+
patched_version=matplotlib_version)
167+
168+
146169
def test_compare_models():
147170
# example dataset
148-
X = np.random.uniform(0, 1, (1000, 50))
149-
n_models = 10
171+
X = np.random.uniform(0, 1, (100, 5))
172+
n_models = 4
150173

151174
fig = plt.figure(figsize=(5, 5))
152175
ax = fig.add_subplot()

0 commit comments

Comments
 (0)