-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataloader.py
More file actions
27 lines (24 loc) · 942 Bytes
/
dataloader.py
File metadata and controls
27 lines (24 loc) · 942 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import mnist_web
import numpy as np
import random
import torch
import sys
class DataLoader():
    """Load an MNIST split via ``mnist_web`` and expose it as torch tensors.

    Image pixels are scaled by 255, truncated to ``int32`` and binarized to
    ``-1`` (value <= ``threshold``) or ``+1`` (value > ``threshold``).
    Labels are cast to ``int32``.

    Args:
        train: Load the training split when True, otherwise the test split.
        cuda: When True, cast both tensors to float32 and move them to the
            default CUDA device.
        path: Directory passed to ``mnist_web.mnist``. Defaults to the
            current directory, matching the previous hard-coded value.
        threshold: Cut-off on the 0-255 scale separating -1 from +1.
            Defaults to 128, matching the previous hard-coded value.
    """

    def __init__(self, train=True, cuda=False, *, path='.', threshold=128):
        # mnist_web.mnist returns a 4-tuple; judging by the unpacking, it is
        # (train_images, train_labels, test_images, test_labels) — TODO confirm.
        if train:
            images, labels, _, _ = mnist_web.mnist(path=path)
        else:
            _, _, images, labels = mnist_web.mnist(path=path)
        # NOTE(review): an older commented-out line collapsed labels with
        # np.sum(labels * np.arange(0, 10), 1), which suggests labels are
        # one-hot (N, 10) — verify before relying on their shape.
        self.images = torch.from_numpy(self._binarize(images, threshold))
        self.labels = torch.from_numpy(labels.astype('int32'))
        if cuda:
            # NOTE(review): tensors stay int32 on the CPU path but become
            # float32 here; preserved as-is because callers may depend on it.
            self.images = self.images.float().cuda()
            self.labels = self.labels.float().cuda()

    @staticmethod
    def _binarize(images, threshold=128):
        """Return an int32 copy of *images* with every pixel mapped to -1 or +1.

        Pixels are multiplied by 255 and truncated to int32 first (presumably
        the input holds intensities in [0, 1] — TODO confirm against
        mnist_web). Values > ``threshold`` become 1; all others become -1.
        *images* itself is never modified.
        """
        # Scale into a new array rather than using ``*=`` so the array handed
        # back by mnist_web is not mutated in place.
        scaled = (images * 255).astype('int32')
        # A single np.where pass decides each pixel from its original value;
        # sequential masked assignment would re-read the -1s it just wrote.
        binary = np.where(scaled > threshold, np.int32(1), np.int32(-1))
        return binary.astype('int32', copy=False)

    def get_all(self):
        """Return the ``(images, labels)`` tensors as a tuple."""
        return self.images, self.labels