-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
62 lines (50 loc) · 1.93 KB
/
utils.py
File metadata and controls
62 lines (50 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from torch import utils
from torchvision import datasets, transforms
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
def get_dataset_mean_std(dataset, batch_size=1000, num_workers=4):
    """Compute the per-channel mean and standard deviation of a dataset.

    The dataset is read once through a ``DataLoader``. Every batch is viewed
    as ``(batch, channel, values)`` and the per-channel sum and sum of squares
    are accumulated, so the result is the statistic of all values of a channel
    across the whole dataset (population standard deviation), not an average
    of per-sample standard deviations.

    Args:
        dataset: Dataset yielding ``(data, target)`` pairs whose ``data`` is a
            tensor with the channel on its first axis, so that batches have
            the channel on axis 1.
        batch_size: Number of samples loaded per step.
        num_workers: Number of ``DataLoader`` worker processes; ``0`` loads in
            the calling process.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel.

    Raises:
        ValueError: If the dataset yields no values.
    """
    loader = utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

    channel_sum = 0.
    channel_sq_sum = 0.
    value_count = 0
    out_dtype = None
    for data, _ in loader:
        if out_dtype is None:
            # Hand the result back in the precision the data came in.
            out_dtype = data.dtype if data.is_floating_point() else torch.get_default_dtype()
        # Accumulate in float64 so that summing many values does not lose precision.
        flat = data.reshape(data.size(0), data.size(1), -1).double()
        channel_sum = channel_sum + flat.sum(dim=(0, 2))
        channel_sq_sum = channel_sq_sum + (flat * flat).sum(dim=(0, 2))
        value_count += flat.size(0) * flat.size(2)

    if value_count == 0:
        raise ValueError('cannot compute mean and std of an empty dataset')

    mean = channel_sum / value_count
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative rounding residue.
    variance = (channel_sq_sum / value_count - mean * mean).clamp(min=0)
    std = variance.sqrt()
    return mean.to(out_dtype), std.to(out_dtype)
# Default preprocessing for MNIST samples: a pipeline holding only ToTensor.
mnist_pre_transform = transforms.Compose([transforms.ToTensor()])
def mnist(batch_size=50, shuffle=True, transform=mnist_pre_transform, path='./MNIST_data'):
    """Build normalized MNIST train and test data loaders.

    The per-channel mean and std are measured on the training split as seen
    through ``transform`` and printed. Both splits are then served through
    ``transform`` followed by ``Normalize`` with those training statistics, so
    the statistics always match the data they are applied to.

    Args:
        batch_size: Number of samples per batch in both loaders.
        shuffle: Passed to both the train and the test loader.
        transform: Preprocessing applied to every sample before normalization.
            It has to produce tensors, since the statistics pass and
            ``Normalize`` both operate on tensors.
        path: Root directory of the MNIST data (``download=True`` is used).

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    # First pass over the training split, un-normalized, to measure statistics.
    raw_train = datasets.MNIST(path, train=True, download=True, transform=transform)
    train_mean, train_std = get_dataset_mean_std(raw_train)
    print('Train mean and std', train_mean, train_std)

    # Normalize on top of the caller's transform instead of replacing it, so the
    # measured statistics and the normalized data come from the same pipeline.
    normalized_transform = transforms.Compose([
        transform,
        transforms.Normalize(train_mean, train_std),
    ])

    # Using same transform (with training statistics) for both sets
    train_data = datasets.MNIST(path, train=True, download=True, transform=normalized_transform)
    test_data = datasets.MNIST(path, train=False, download=True, transform=normalized_transform)

    train_loader = utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=shuffle)
    test_loader = utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)
    return train_loader, test_loader
def plot_mnist(images, shape):
    """Display a batch of images on a grid of subplots.

    Channel 0 of every entry of ``images`` is drawn with ``matshow`` using the
    binary colormap, axis ticks are removed, and the figure is shown.

    Args:
        images: Indexable batch addressed as ``images[i, 0, :, :]``.
        shape: Grid layout as ``(rows, columns)``; its reverse is used as the
            figure size.
    """
    figure = plt.figure(figsize=shape[::-1], dpi=80)
    n_rows, n_cols = shape[0], shape[1]

    for index in range(len(images)):
        # Subplot positions are 1-based.
        axes = figure.add_subplot(n_rows, n_cols, index + 1)
        axes.matshow(images[index, 0, :, :], cmap=matplotlib.cm.binary)
        # The pyplot tick helpers act on the current axes; an empty array removes the ticks.
        plt.xticks(np.array([]))
        plt.yticks(np.array([]))

    plt.show()