Skip to content

Commit 7f9d06b

Browse files
committed
local johan changes
1 parent ba5f173 commit 7f9d06b

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def __init__(
3333
self.mnist_path = self.data_path / "MNIST"
3434
self.samples = sample_ids
3535
self.train = train
36+
self.transform = transform
37+
self.num_classes = 6
3638

3739
self.images_path = self.mnist_path / (
3840
MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]
@@ -46,7 +48,7 @@ def __len__(self):
4648

4749
def __getitem__(self, idx):
4850
with open(self.labels_path, "rb") as labelfile:
49-
label_pos = 8 + self.sample[idx]
51+
label_pos = 8 + self.samples[idx]
5052
labelfile.seek(label_pos)
5153
label = int.from_bytes(labelfile.read(1), byteorder="big")
5254

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def main():
145145
for x, y in tqdm(trainloader, desc="Training"):
146146
x, y = x.to(device), y.to(device)
147147
logits = model.forward(x)
148-
148+
from IPython import embed; embed()
149149
loss = criterion(logits, y)
150150
loss.backward()
151151

0 commit comments

Comments
 (0)