Skip to content

Commit e3c6d40

Browse files
committed
Shape CIFAR into an image, like MNIST and SVHN.
1 parent 72eac47 commit e3c6d40

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

DeepFried2/datasets/cifar10.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,21 @@ def data():
1818
batch = _pickle.load(b, encoding='latin1')
1919
datas.append(_np.array(batch['data'], dtype=_np.float32))
2020
labels.append(_np.array(batch['labels']))
21-
Xtr = _np.concatenate(datas)
21+
Xtr = _np.concatenate(datas).reshape((-1, 3, 32, 32))
2222
ytr = _np.concatenate(labels)
2323
Xtr /= 255
2424

2525
# ... and the fifth as validation set as described in cuda-convnet:
2626
# https://code.google.com/p/cuda-convnet/wiki/Methodology
2727
with f.extractfile('cifar-10-batches-py/data_batch_5') as b:
2828
batch = _pickle.load(b, encoding='latin1')
29-
Xva = _np.array(batch['data'], dtype=_np.float32)
29+
Xva = _np.array(batch['data'], dtype=_np.float32).reshape((-1, 3, 32, 32))
3030
yva = _np.array(batch['labels'])
3131
Xva /= 255
3232

3333
with f.extractfile('cifar-10-batches-py/test_batch') as b:
3434
batch = _pickle.load(b, encoding='latin1')
35-
Xte = _np.array(batch['data'], dtype=_np.float32)
35+
Xte = _np.array(batch['data'], dtype=_np.float32).reshape((-1, 3, 32, 32))
3636
yte = _np.array(batch['labels'])
3737
Xte /= 255
3838

DeepFried2/datasets/cifar100.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def data():
1313
with _taropen(fname, 'r') as f:
1414
with f.extractfile('cifar-100-python/train') as train:
1515
train = _pickle.load(train, encoding='latin1')
16-
Xtr = _np.array(train['data'], dtype=_np.float32)
16+
Xtr = _np.array(train['data'], dtype=_np.float32).reshape((-1, 3, 32, 32))
1717
ytr_c = _np.array(train['coarse_labels'])
1818
ytr_f = _np.array(train['fine_labels'])
1919
Xtr /= 255
@@ -27,7 +27,7 @@ def data():
2727

2828
with f.extractfile('cifar-100-python/test') as test:
2929
test = _pickle.load(test, encoding='latin1')
30-
Xte = _np.array(test['data'], dtype=_np.float32)
30+
Xte = _np.array(test['data'], dtype=_np.float32).reshape((-1, 3, 32, 32))
3131
yte_c = _np.array(test['coarse_labels'])
3232
yte_f = _np.array(test['fine_labels'])
3333
Xte /= 255

0 commit comments

Comments
 (0)