Skip to content

Commit 8280654

Browse files
committed
Fix tests
* bump version * update dockerfile * fix progress bar * remove outdated test * rename models
1 parent 8529a93 commit 8280654

File tree

7 files changed

+13
-58
lines changed

7 files changed

+13
-58
lines changed

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ RUN make dist
4040
FROM cebra-base
4141

4242
# install the cebra wheel
43-
ENV WHEEL=cebra-0.4.0-py2.py3-none-any.whl
43+
ENV WHEEL=cebra-0.4.0+regcl-py2.py3-none-any.whl
4444
WORKDIR /build
4545
COPY --from=wheel /build/dist/${WHEEL} .
46-
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]'
46+
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets,regcl]'
4747
RUN rm -rf /build
4848

4949
# add the repository

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
CEBRA_VERSION := 0.4.0
1+
CEBRA_VERSION := 0.4.0+regcl
22

33
dist:
44
python3 -m pip install virtualenv

PKGBUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Maintainer: Steffen Schneider <[email protected]>
22
pkgname=python-cebra
33
_pkgname=cebra
4-
pkgver=0.4.0
4+
pkgver=0.4.0+regcl
55
pkgrel=1
66
pkgdesc="Consistent Embeddings of high-dimensional Recordings using Auxiliary variables"
77
url="https://cebra.ai"

cebra/models/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
871871

872872

873873
@register("offset1-model-mse-tanh")
874-
class Offset0ModelMSE(_OffsetModel):
874+
class Offset0ModelMSETanH(_OffsetModel):
875875
"""CEBRA model with a single sample receptive field, without output normalization."""
876876

877877
def __init__(self, num_neurons, num_units, num_output, normalize=False):
@@ -901,7 +901,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
901901
@parametrize("offset1-model-mse-clip-{clip_min}-{clip_max}",
902902
clip_min=(1000, 100, 50, 25, 20, 15, 10, 5, 1),
903903
clip_max=(1000, 100, 50, 25, 20, 15, 10, 5, 1))
904-
class Offset0ModelMSE(_OffsetModel):
904+
class Offset0ModelMSEClip(_OffsetModel):
905905
"""CEBRA model with a single sample receptive field, without output normalization."""
906906

907907
def __init__(self,
@@ -942,7 +942,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
942942
@parametrize("offset1-model-mse-v2-{n_intermediate_layers}layers{tanh}",
943943
n_intermediate_layers=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
944944
tanh=("-tanh", ""))
945-
class Offset0Model(_OffsetModel):
945+
class Offset0ModelMSETanHv2(_OffsetModel):
946946
"""CEBRA model with a single sample receptive field, without output normalization."""
947947

948948
def __init__(self,
@@ -993,7 +993,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
993993
@parametrize("offset1-model-mse-resnet-{n_intermediate_layers}layers{tanh}",
994994
n_intermediate_layers=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
995995
tanh=("-tanh", ""))
996-
class Offset0Model(_OffsetModel):
996+
class Offset0ModelResNetTanH(_OffsetModel):
997997
"""CEBRA model with a single sample receptive field, without output normalization."""
998998

999999
def __init__(self,

cebra/solver/util.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,16 @@ def __post_init__(self):
9090
raise ValueError(
9191
f"log_format must be one of {self._valid_formats}, "
9292
f"but got {self.log_formats}")
93+
self._stats = None
9394

9495
def __iter__(self):
9596
self.iterator = self.loader
9697
if self.use_tqdm:
9798
self.iterator = tqdm.tqdm(self.iterator)
9899
for num_batch, batch in enumerate(self.iterator):
99100
yield num_batch, batch
100-
self._log_message(num_batch, self.iterator.stats)
101-
self._log_message(num_batch, self.iterator.stats)
101+
self._log_message(num_batch, self._stats)
102+
self._log_message(num_batch, self._stats)
102103

103104
def _log_message(self, num_steps, stats):
104105
if self.logger is None:
@@ -119,4 +120,4 @@ def set_description(self, stats: Dict[str, float]):
119120
if self.use_tqdm:
120121
self.iterator.set_description(_description(stats))
121122

122-
self.iterator.stats = stats
123+
self._stats = stats

reinstall.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pip uninstall -y cebra
1515
# Get version info after uninstalling --- this will automatically get the
1616
# most recent version based on the source code in the current directory.
1717
# $(tools/get_cebra_version.sh)
18-
VERSION=0.4.0
18+
VERSION=0.4.0+regcl
1919
echo "Upgrading to CEBRA v${VERSION}"
2020

2121
# Upgrade the build system (PEP517/518 compatible)

tests/test_models.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -88,52 +88,6 @@ def test_offset_models(model_name, batch_size, input_length):
8888
assert len(outputs) == batch_size
8989

9090

91-
def test_multiobjective():
92-
93-
class TestModel(cebra.models.Model):
94-
95-
def __init__(self):
96-
super().__init__(num_input=10, num_output=10)
97-
self._model = nn.Linear(self.num_input, self.num_output)
98-
99-
def forward(self, x):
100-
return self._model(x)
101-
102-
@property
103-
def get_offset(self):
104-
return None
105-
106-
model = TestModel()
107-
108-
multi_model_overlap = cebra.models.MultiobjectiveModel(
109-
model,
110-
dimensions=(4, 6),
111-
output_mode="overlapping",
112-
append_last_dimension=True)
113-
multi_model_separate = cebra.models.MultiobjectiveModel(
114-
model,
115-
dimensions=(4, 6),
116-
output_mode="separate",
117-
append_last_dimension=True)
118-
119-
x = torch.randn(5, 10)
120-
121-
assert model(x).shape == (5, 10)
122-
123-
assert model.num_output == multi_model_overlap.num_output
124-
assert model.get_offset == multi_model_overlap.get_offset
125-
126-
first, second, third = multi_model_overlap(x)
127-
assert first.shape == (5, 4)
128-
assert second.shape == (5, 6)
129-
assert third.shape == (5, 10)
130-
131-
first, second, third = multi_model_separate(x)
132-
assert first.shape == (5, 4)
133-
assert second.shape == (5, 2)
134-
assert third.shape == (5, 4)
135-
136-
13791
@pytest.mark.parametrize("version,raises", [
13892
["1.12", False],
13993
["2.", False],

0 commit comments

Comments
 (0)