Skip to content

Commit 3f06b71

Browse files
authored
Enhance ModuleNotFoundError messages (#85)
* Enhance the module not found err messages in cebra.load * Moved the monkey reaching error message * Raise an error instead of a warning for nlb_tools * Seed the criterion tests
1 parent 8520b9b commit 3f06b71

File tree

4 files changed

+76
-31
lines changed

4 files changed

+76
-31
lines changed

cebra/data/load.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@
8181
)
8282

8383

84+
def _module_not_found_error(module_name):
85+
return ModuleNotFoundError(
86+
f"Could not load {module_name}. You can manually install {module_name} "
87+
"or install the [datasets] dependency in cebra: "
88+
"pip install 'cebra[datasets]'")
89+
90+
8491
class _BaseLoader(abc.ABC):
8592
"""Base loader."""
8693

@@ -186,7 +193,13 @@ def load(
186193
keypoints=columns)
187194
else:
188195
raise ModuleNotFoundError(
189-
"DLC integration could not be loaded.")
196+
"DLC integration could not be loaded. "
197+
"Most likely, this is because you do not have all "
198+
"integrations dependencies installed. Try installing "
199+
"cebra with the [integrations] and [datasets] dependency to fix this "
200+
"error. You might need to re-start your environment "
201+
"after installing: "
202+
"pip install 'cebra[integrations,datasets]'.")
190203
# if the provided key is valid
191204
elif key in df_keys:
192205
loaded_array = _PandasLoader.load_from_h5(
@@ -208,7 +221,7 @@ def load(
208221
raise AttributeError(
209222
"No valid data structure was found in your file.")
210223
else:
211-
raise ModuleNotFoundError()
224+
raise _module_not_found_error("h5py")
212225
return loaded_array
213226

214227
@staticmethod
@@ -385,7 +398,7 @@ def load(
385398
except pd.errors.EmptyDataError:
386399
raise AttributeError(".csv file is empty.")
387400
else:
388-
raise ModuleNotFoundError()
401+
raise _module_not_found_error("pandas")
389402
return loaded_array
390403

391404

@@ -420,7 +433,7 @@ def load(
420433
loaded_array = loaded_dict[key].values
421434
break
422435
else:
423-
raise ModuleNotFoundError()
436+
raise _module_not_found_error("pandas")
424437
return loaded_array
425438

426439
# def prepare_engine(extension: str):
@@ -486,7 +499,7 @@ def load(
486499
raise NotImplementedError(
487500
f"{type(loaded_data)} is not handled for .jl files.")
488501
else:
489-
raise ModuleNotFoundError()
502+
raise _module_not_found_error("joblib")
490503
return loaded_array
491504

492505

@@ -531,7 +544,7 @@ def load(
531544
raise NotImplementedError(
532545
f"{type(loaded_data)} is not handled for .pk files.")
533546
else:
534-
raise ModuleNotFoundError()
547+
raise _module_not_found_error("pickle")
535548
return loaded_array
536549

537550

@@ -572,7 +585,7 @@ def load(
572585
if _IS_H5PY_AVAILABLE:
573586
loaded_array = _H5pyLoader.load(file, key)
574587
else:
575-
raise ModuleNotFoundError()
588+
raise _module_not_found_error("h5py")
576589
return loaded_array
577590

578591

cebra/datasets/monkey_reaching.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,6 @@
2727
import scipy.io
2828
import torch
2929

30-
try:
31-
from nlb_tools.nwb_interface import NWBDataset
32-
except ImportError:
33-
import warnings
34-
35-
warnings.warn(
36-
("Could not import the nlb_tools package required for data loading "
37-
"of cebra.datasets.monkey_reaching. Dataset will not be available. "
38-
"If required, you can install the dataset by running "
39-
"pip install git+https://github.com/neurallatents/nlb_tools."))
40-
4130
import cebra.data
4231
from cebra.datasets import get_datapath
4332
from cebra.datasets import register
@@ -62,6 +51,16 @@ def _load_data(
6251
6352
"""
6453

54+
try:
55+
from nlb_tools.nwb_interface import NWBDataset
56+
except ImportError as e:
57+
raise ImportError(
58+
"Could not import the nlb_tools package required for loading "
59+
"the raw reaching datasets in NWB format. "
60+
"If required, you can install the dataset by running "
61+
"pip install nlb_tools or installing cebra with the [datasets] "
62+
"dependencies: pip install 'cebra[datasets]'") from e
63+
6564
def _get_info(trial_info, data):
6665
passive = []
6766
direction = []

tests/test_criterions.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -276,10 +276,16 @@ def _compute_grads(output, inputs):
276276
return [input_.grad for input_ in inputs]
277277

278278

279-
def test_infonce():
279+
def _sample_dist_matrices(seed):
280+
rng = torch.Generator().manual_seed(seed)
281+
pos_dist = torch.randn(100, generator=rng)
282+
neg_dist = torch.randn(100, 100, generator=rng)
283+
return pos_dist, neg_dist
284+
280285

281-
pos_dist = torch.randn(100,)
282-
neg_dist = torch.randn(100, 100)
286+
@pytest.mark.parametrize("seed", [42, 4242, 424242])
287+
def test_infonce(seed):
288+
pos_dist, neg_dist = _sample_dist_matrices(seed)
283289

284290
ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist)
285291
loss, align, uniform = cebra_criterions.infonce(pos_dist, neg_dist)
@@ -290,11 +296,9 @@ def test_infonce():
290296
assert torch.allclose(align + uniform, loss)
291297

292298

293-
def test_infonce_gradients():
294-
295-
rng = torch.Generator().manual_seed(42)
296-
pos_dist = torch.randn(100, generator=rng)
297-
neg_dist = torch.randn(100, 100, generator=rng)
299+
@pytest.mark.parametrize("seed", [42, 4242, 424242])
300+
def test_infonce_gradients(seed):
301+
pos_dist, neg_dist = _sample_dist_matrices(seed)
298302

299303
for i in range(3):
300304
pos_dist_ = pos_dist.clone()
@@ -312,7 +316,7 @@ def test_infonce_gradients():
312316
grad = _compute_grads(loss, [pos_dist_, neg_dist_])
313317

314318
# NOTE(stes) default relative tolerance is 1e-5
315-
assert torch.allclose(loss_ref, loss, rtol = 1e-4)
319+
assert torch.allclose(loss_ref, loss, rtol=1e-4)
316320

317321
if i == 0:
318322
assert grad[0] is not None

tests/test_load.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import pathlib
1313
import pickle
1414
import tempfile
15+
import unittest
16+
from unittest.mock import patch
1517

1618
import h5py
1719
import hdf5storage
@@ -27,6 +29,7 @@
2729

2830
__test_functions = []
2931
__test_functions_error = []
32+
__test_functions_module_not_found = []
3033

3134

3235
def _skip_hdf5storage(*args, **kwargs):
@@ -42,7 +45,7 @@ def test_imports():
4245
assert hasattr(cebra, "load_data")
4346

4447

45-
def register(*file_endings):
48+
def register(*file_endings, requires=()):
4649
# for each file format
4750
def _register(f):
4851
# f is the filename
@@ -53,6 +56,12 @@ def _register(f):
5356
lambda filename: f(filename + "." + file_ending)
5457
for file_ending in file_endings
5558
])
59+
if len(requires) > 0:
60+
__test_functions_module_not_found.extend([
61+
(requires, lambda filename: filename + "." + file_ending,
62+
lambda filename: f(filename + "." + file_ending))
63+
for file_ending in file_endings
64+
])
5665
return f
5766

5867
return _register
@@ -152,7 +161,7 @@ def generate_numpy_no_array(filename):
152161
# TODO: test raise ModuleFoundError for h5py
153162

154163

155-
@register("h5", "hdf", "hdf5", "h")
164+
@register("h5", "hdf", "hdf5", "h", requires=("h5py",))
156165
def generate_h5(filename):
157166
A = np.arange(1000).reshape(10, 100)
158167
with h5py.File(filename, "w") as hf:
@@ -380,7 +389,7 @@ def generate_wrong_key(filename):
380389

381390

382391
#### .CSV ####
383-
@register("csv")
392+
@register("csv", requires=("pandas",))
384393
def generate_csv(filename):
385394
A = np.arange(1000).reshape(10, 100)
386395
pd.DataFrame(A).to_csv(filename, header=False, index=False, sep=",")
@@ -404,7 +413,7 @@ def generate_csv_empty_file(filename):
404413

405414

406415
#### EXCEL ####
407-
@register("xls", "xlsx", "xlsm")
416+
@register("xls", "xlsx", "xlsm", requires=("pandas",))
408417
# TODO(celia): add the following extension: "xlsb", "odf", "ods", "odt",
409418
# issue to create the files
410419
def generate_excel(filename):
@@ -777,3 +786,23 @@ def test_load_error(save_data):
777786

778787
with pytest.raises((AttributeError, TypeError)):
779788
save_data(filename)
789+
790+
791+
@pytest.mark.parametrize("module_names,get_path,save_data",
792+
__test_functions_module_not_found)
793+
def test_module_not_installed(module_names, get_path, save_data):
794+
795+
assert len(module_names) > 0
796+
assert isinstance(module_names, tuple)
797+
798+
with tempfile.NamedTemporaryFile() as tf:
799+
filename = tf.name
800+
801+
saved_array, loaded_array = save_data(filename)
802+
assert np.allclose(saved_array, loaded_array)
803+
804+
# TODO(stes): Sketch for a test --- needs additional work.
805+
# with patch.dict('sys.modules', {module: None for module in module_names}):
806+
# path = get_path(filename)
807+
#     with pytest.raises(ModuleNotFoundError, match=r"cebra\[datasets\]"):
808+
# cebra.data.load.load(path)

0 commit comments

Comments
 (0)