File tree Expand file tree Collapse file tree 2 files changed +4
-2
lines changed
CollaborativeCoding/dataloaders Expand file tree Collapse file tree 2 files changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -33,6 +33,8 @@ def __init__(
3333 self .mnist_path = self .data_path / "MNIST"
3434 self .samples = sample_ids
3535 self .train = train
36+ self .transform = transform
37+ self .num_classes = 6
3638
3739 self .images_path = self .mnist_path / (
3840 MNIST_SOURCE ["train_images" ][1 ] if train else MNIST_SOURCE ["test_images" ][1 ]
@@ -46,7 +48,7 @@ def __len__(self):
4648
4749 def __getitem__ (self , idx ):
4850 with open (self .labels_path , "rb" ) as labelfile :
49- label_pos = 8 + self .sample [idx ]
51+ label_pos = 8 + self .samples [idx ]
5052 labelfile .seek (label_pos )
5153 label = int .from_bytes (labelfile .read (1 ), byteorder = "big" )
5254
Original file line number Diff line number Diff line change @@ -145,7 +145,7 @@ def main():
145145 for x , y in tqdm (trainloader , desc = "Training" ):
146146 x , y = x .to (device ), y .to (device )
147147 logits = model .forward (x )
148-
148+ from IPython import embed ; embed ()
149149 loss = criterion (logits , y )
150150 loss .backward ()
151151
You can’t perform that action at this time.
0 commit comments