Commit 8b9bb1a

gonlairo, MMathisLab, and stes authored
Fix windows tests (#90)
* fix test_load adding explicit int types
* update build.yml with windows
* os independent paths in code examples
* add doctest skip when saving test file
* mute output with ; instead of with doctest skip
* capture stdout
* parametrize dtype
* improve skip condition
* improve skipping test message
* change names from "foo.pt" to "cebra.pt"
* remove files after example is completed with path.unlink()
* fit and save a model so the temporary file exists

---------

Co-authored-by: Mackenzie Mathis <[email protected]>
Co-authored-by: Steffen Schneider <[email protected]>
1 parent f45e69d commit 8b9bb1a

File tree

5 files changed: +210 −169 lines
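The common thread of this commit is path portability: the old examples wrote to hard-coded POSIX paths such as '/tmp/foo.pt', which do not exist on Windows runners. A minimal sketch of the replacement pattern used throughout the diffs below (the file name is illustrative):

    import tempfile
    from pathlib import Path

    # gettempdir() resolves to a platform-appropriate location,
    # e.g. /tmp on Linux or C:\Users\<user>\AppData\Local\Temp on Windows.
    tmp_file = Path(tempfile.gettempdir(), "cebra.pt")
    print(tmp_file)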

.github/workflows/build.yml

Lines changed: 3 additions & 4 deletions
@@ -23,10 +23,9 @@ jobs:
           - os: ubuntu-latest
             python-version: 3.8
             torch-version: 1.9.0
-          # TODO(stes): Include at a later stage
-          #- os: windows-latest
-          #  torch-version: 2.0.0
-          #  python-version: "3.10"
+          - os: windows-latest
+            torch-version: 2.0.0
+            python-version: "3.10"
           #- os: macos-latest
           #  torch-version: 2.0.0
           #  python-version: "3.10"
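The matrix entry above only turns the Windows runner on; the test-side adjustments listed in the commit message (explicit int types, a parametrized dtype, a clearer skip condition and message) live in the test file, which is not among the diffs shown here. A hypothetical sketch of that style of test, not the actual code from the PR:

    import numpy as np
    import pytest

    # Illustrative only: parametrize over dtype names and skip the ones a platform
    # does not provide (np.float128 is typically unavailable on Windows), giving a
    # descriptive reason in the skip message.
    @pytest.mark.parametrize("dtype_name", ["int32", "int64", "float32", "float128"])
    def test_dtype_roundtrip(dtype_name):
        if not hasattr(np, dtype_name):
            pytest.skip(f"{dtype_name} is not available on this platform")
        # explicit dtype instead of relying on the platform default integer size
        labels = np.arange(100, dtype=getattr(np, dtype_name))
        assert labels.dtype == np.dtype(dtype_name)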

cebra/integrations/sklearn/cebra.py

Lines changed: 20 additions & 4 deletions
@@ -1161,14 +1161,18 @@ def fit(

            >>> import cebra
            >>> import numpy as np
+           >>> import tempfile
+           >>> from pathlib import Path
+           >>> tmp_file = Path(tempfile.gettempdir(), 'cebra.pt')
            >>> dataset = np.random.uniform(0, 1, (1000, 20))
            >>> dataset2 = np.random.uniform(0, 1, (1000, 40))
            >>> cebra_model = cebra.CEBRA(max_iterations=10)
            >>> cebra_model.fit(dataset)
            CEBRA(max_iterations=10)
-           >>> cebra_model.save('/tmp/foo.pt')
+           >>> cebra_model.save(tmp_file)
            >>> cebra_model.fit(dataset2, adapt=True)
            CEBRA(max_iterations=10)
+           >>> tmp_file.unlink()

        """
        if adapt and sklearn_utils.check_fitted(self):
            self._adapt_fit(X,
@@ -1332,11 +1336,15 @@ def save(self,

            >>> import cebra
            >>> import numpy as np
+           >>> import tempfile
+           >>> from pathlib import Path
+           >>> tmp_file = Path(tempfile.gettempdir(), 'test.jl')
            >>> dataset = np.random.uniform(0, 1, (1000, 30))
            >>> cebra_model = cebra.CEBRA(max_iterations=10)
            >>> cebra_model.fit(dataset)
            CEBRA(max_iterations=10)
-           >>> cebra_model.save('/tmp/foo.pt')
+           >>> cebra_model.save(tmp_file)
+           >>> tmp_file.unlink()

        """
        if sklearn_utils.check_fitted(self):
@@ -1394,10 +1402,18 @@ def load(cls,
        Example:

            >>> import cebra
-           >>> import numpy as np
+           >>> import numpy as np
+           >>> import tempfile
+           >>> from pathlib import Path
+           >>> tmp_file = Path(tempfile.gettempdir(), 'cebra.pt')
            >>> dataset = np.random.uniform(0, 1, (1000, 20))
-           >>> loaded_model = cebra.CEBRA.load('/tmp/foo.pt')
+           >>> cebra_model = cebra.CEBRA(max_iterations=10)
+           >>> cebra_model.fit(dataset)
+           CEBRA(max_iterations=10)
+           >>> cebra_model.save(tmp_file)
+           >>> loaded_model = cebra.CEBRA.load(tmp_file)
            >>> embedding = loaded_model.transform(dataset)
+           >>> tmp_file.unlink()

        """
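Every docstring example above now follows the same round trip: build an OS-independent temp path, fit, save, reload, and delete the file so the example leaves nothing behind. A condensed, standalone version of that pattern, mirroring the load() example in the diff:

    import tempfile
    from pathlib import Path

    import numpy as np
    import cebra

    tmp_file = Path(tempfile.gettempdir(), "cebra.pt")   # portable temp location

    dataset = np.random.uniform(0, 1, (1000, 20))
    cebra_model = cebra.CEBRA(max_iterations=10)
    cebra_model.fit(dataset)                              # fit first so there is something to save

    cebra_model.save(tmp_file)
    loaded_model = cebra.CEBRA.load(tmp_file)
    embedding = loaded_model.transform(dataset)

    tmp_file.unlink()                                     # remove the file once the example is done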

cebra/io.py

Lines changed: 10 additions & 8 deletions
@@ -214,9 +214,11 @@ class FileKeyValueDataset:

        >>> import cebra.io
        >>> import joblib
-       >>> joblib.dump({'foo' : 42}, '/tmp/test.jl')
-       ['/tmp/test.jl']
-       >>> data = cebra.io.FileKeyValueDataset('/tmp/test.jl')
+       >>> import tempfile
+       >>> from pathlib import Path
+       >>> tmp_file = Path(tempfile.gettempdir(), 'test.jl')
+       >>> _ = joblib.dump({'foo' : 42}, tmp_file)
+       >>> data = cebra.io.FileKeyValueDataset(tmp_file)
        >>> data.foo
        42
@@ -242,14 +244,14 @@ def __repr__(self):
        return f"{type(self).__name__}(keys=(\n {sizes}\n))"

    def _iterate_items(self):
-       extension = self.path.split(".")[-1]
-       if extension in ["jl", "joblib"]:
+       extension = self.path.suffix
+       if extension in [".jl", ".joblib"]:
            dataset = joblib.load(self.path)
-       elif extension in ["h5", "hdf", "hdf5"]:
+       elif extension in [".h5", ".hdf", ".hdf5"]:
            raise NotImplementedError()
-       elif extension in ["pth", "pt"]:
+       elif extension in [".pth", ".pt"]:
            dataset = torch.load(self.path)
-       elif extension in ["npz"]:
+       elif extension in [".npz"]:
            dataset = np.load(self.path, allow_pickle=True)
        else:
            raise ValueError(f"Invalid file format: {extension} in {self.path}")
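The extension check now uses pathlib's suffix attribute, which works directly on the Path objects the updated examples pass in (str.split would fail there) and keeps the leading dot, hence the updated membership lists. A small sketch of the difference, with an illustrative file name:

    from pathlib import Path

    path = Path(r"C:\Users\me\data\session.jl")

    str(path).split(".")[-1]            # 'jl'  (string-based, drops the dot)
    path.suffix                         # '.jl' (pathlib keeps the dot)
    path.suffix in [".jl", ".joblib"]   # True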

docs/source/usage.rst

Lines changed: 11 additions & 4 deletions
@@ -494,14 +494,20 @@ The model will be saved as a ``.pt`` file.

    .. testcode::

+       import tempfile
+       from pathlib import Path
+
+       # create temporary file to save the model
+       tmp_file = Path(tempfile.gettempdir(), 'cebra.pt')
+
        cebra_model = cebra.CEBRA(max_iterations=10)
        cebra_model.fit(neural_data)

        # Save the model
-       cebra_model.save('/tmp/foo.pt')
+       cebra_model.save(tmp_file)

        # New session: load and use the model
-       loaded_cebra_model = cebra.CEBRA.load('/tmp/foo.pt')
+       loaded_cebra_model = cebra.CEBRA.load(tmp_file)
        embedding = loaded_cebra_model.transform(neural_data)

@@ -1221,10 +1227,11 @@ Putting all previous snippet examples together, we obtain the following pipeline

        cebra_model.fit(train_data, train_discrete_label, train_continuous_label)

        # 5. Save the model
-       cebra_model.save('/tmp/foo.pt')
+       tmp_file = Path(tempfile.gettempdir(), 'cebra.pt')
+       cebra_model.save(tmp_file)

        # 6. Load the model and compute an embedding
-       cebra_model = cebra.CEBRA.load('/tmp/foo.pt')
+       cebra_model = cebra.CEBRA.load(tmp_file)
        train_embedding = cebra_model.transform(train_data)
        valid_embedding = cebra_model.transform(valid_data)
        assert train_embedding.shape == (70, 8)
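Note that the pipeline snippet above relies on tempfile and Path already being imported earlier on the docs page (in the save/load section). If the snippet is run on its own, the imports need to be added, and the temporary model file can optionally be removed afterwards, mirroring the docstring examples:

    import tempfile
    from pathlib import Path

    # ... run the pipeline above ...

    tmp_file.unlink()   # optional clean-up of the saved model file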
