Skip to content

Commit 74e29d2

Browse files
committed
Ruff format
1 parent 0fd3cd3 commit 74e29d2

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,11 @@ def __init__(
4343
self.labels_path = self.mnist_path / (
4444
MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
4545
)
46-
47-
# Functions to map the labels from (4,9) -> (0,5) for CrossEntropyLoss to work properly.
48-
self.label_shift = lambda x: x-4
49-
self.label_restore = lambda x: x+4
50-
51-
46+
47+
# Functions to map the labels from (4,9) -> (0,5) for CrossEntropyLoss to work properly.
48+
self.label_shift = lambda x: x - 4
49+
self.label_restore = lambda x: x + 4
50+
5251
def __len__(self):
5352
return len(self.samples)
5453

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def main():
139139

140140
for epoch in range(args.epoch):
141141
# Training loop start
142-
print(f"Epoch: {epoch+1}/{args.epoch}")
142+
print(f"Epoch: {epoch + 1}/{args.epoch}")
143143
trainingloss = []
144144
model.train()
145145
for x, y in tqdm(trainloader, desc="Training"):

0 commit comments

Comments
 (0)