Skip to content

Commit e201e00

Browse files
thomasmarchioro3thomasmarchioro3
andauthored
Fixed SMoG examples (#1876)
* fixed smog examples * fixed format and autogenerated notebooks * applied suggested changes --------- Co-authored-by: thomasmarchioro3 <thomasmarchioro3@github.com>
1 parent 1d2eb09 commit e201e00

File tree

4 files changed

+11
-7
lines changed

4 files changed

+11
-7
lines changed

examples/notebooks/pytorch/smog.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
"metadata": {},
77
"source": [
88
"This example requires the following dependencies to be installed:\n",
9-
"pip install lightly"
9+
"pip install lightly\n",
10+
"pip install \"scikit-learn>=1.7.1\""
1011
]
1112
},
1213
{
@@ -107,7 +108,7 @@
107108
" def reset_group_features(self, memory_bank):\n",
108109
" # see https://arxiv.org/pdf/2207.06167.pdf Table 7b)\n",
109110
" features = memory_bank.bank\n",
110-
" group_features = self._cluster_features(features.t())\n",
111+
" group_features = self._cluster_features(features)\n",
111112
" self.smog.set_group_features(group_features)\n",
112113
"\n",
113114
" def reset_momentum_weights(self):\n",

examples/notebooks/pytorch_lightning/smog.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"from lightly import loss, models\n",
6464
"from lightly.models import utils\n",
6565
"from lightly.models.modules import heads\n",
66+
"from lightly.models.modules.memory_bank import MemoryBankModule\n",
6667
"from lightly.transforms.smog_transform import SMoGTransform"
6768
]
6869
},
@@ -93,7 +94,7 @@
9394
" # smog\n",
9495
" self.n_groups = 300\n",
9596
" memory_bank_size = 10000\n",
96-
" self.memory_bank = loss.memory_bank.MemoryBankModule(size=memory_bank_size)\n",
97+
" self.memory_bank = MemoryBankModule(size=(memory_bank_size, 128))\n",
9798
" # create our loss\n",
9899
" group_features = torch.nn.functional.normalize(\n",
99100
" torch.rand(self.n_groups, 128), dim=1\n",
@@ -111,7 +112,7 @@
111112
" def _reset_group_features(self):\n",
112113
" # see https://arxiv.org/pdf/2207.06167.pdf Table 7b)\n",
113114
" features = self.memory_bank.bank\n",
114-
" group_features = self._cluster_features(features.t())\n",
115+
" group_features = self._cluster_features(features)\n",
115116
" self.smog.set_group_features(group_features)\n",
116117
"\n",
117118
" def _reset_momentum_weights(self):\n",

examples/pytorch/smog.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# This example requires the following dependencies to be installed:
22
# pip install lightly
3+
# pip install "scikit-learn>=1.7.1"
34

45
# Note: The model and training settings do not follow the reference settings
56
# from the paper. The settings are chosen such that the example can easily be
@@ -53,7 +54,7 @@ def _cluster_features(self, features: torch.Tensor) -> torch.Tensor:
5354
def reset_group_features(self, memory_bank):
5455
# see https://arxiv.org/pdf/2207.06167.pdf Table 7b)
5556
features = memory_bank.bank
56-
group_features = self._cluster_features(features.t())
57+
group_features = self._cluster_features(features)
5758
self.smog.set_group_features(group_features)
5859

5960
def reset_momentum_weights(self):

examples/pytorch_lightning/smog.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from lightly import loss, models
1717
from lightly.models import utils
1818
from lightly.models.modules import heads
19+
from lightly.models.modules.memory_bank import MemoryBankModule
1920
from lightly.transforms.smog_transform import SMoGTransform
2021

2122

@@ -39,7 +40,7 @@ def __init__(self):
3940
# smog
4041
self.n_groups = 300
4142
memory_bank_size = 10000
42-
self.memory_bank = loss.memory_bank.MemoryBankModule(size=memory_bank_size)
43+
self.memory_bank = MemoryBankModule(size=(memory_bank_size, 128))
4344
# create our loss
4445
group_features = torch.nn.functional.normalize(
4546
torch.rand(self.n_groups, 128), dim=1
@@ -57,7 +58,7 @@ def _cluster_features(self, features: torch.Tensor) -> torch.Tensor:
5758
def _reset_group_features(self):
5859
# see https://arxiv.org/pdf/2207.06167.pdf Table 7b)
5960
features = self.memory_bank.bank
60-
group_features = self._cluster_features(features.t())
61+
group_features = self._cluster_features(features)
6162
self.smog.set_group_features(group_features)
6263

6364
def _reset_momentum_weights(self):

0 commit comments

Comments
 (0)