@@ -18,21 +18,21 @@ def data():
1818 batch = _pickle .load (b , encoding = 'latin1' )
1919 datas .append (_np .array (batch ['data' ], dtype = _np .float32 ))
2020 labels .append (_np .array (batch ['labels' ]))
21- Xtr = _np .concatenate (datas )
21+ Xtr = _np .concatenate (datas ). reshape (( - 1 , 3 , 32 , 32 ))
2222 ytr = _np .concatenate (labels )
2323 Xtr /= 255
2424
2525 # ... and the fifth as validation set as described in cuda-convnet:
2626 # https://code.google.com/p/cuda-convnet/wiki/Methodology
2727 with f .extractfile ('cifar-10-batches-py/data_batch_5' ) as b :
2828 batch = _pickle .load (b , encoding = 'latin1' )
29- Xva = _np .array (batch ['data' ], dtype = _np .float32 )
29+ Xva = _np .array (batch ['data' ], dtype = _np .float32 ). reshape (( - 1 , 3 , 32 , 32 ))
3030 yva = _np .array (batch ['labels' ])
3131 Xva /= 255
3232
3333 with f .extractfile ('cifar-10-batches-py/test_batch' ) as b :
3434 batch = _pickle .load (b , encoding = 'latin1' )
35- Xte = _np .array (batch ['data' ], dtype = _np .float32 )
35+ Xte = _np .array (batch ['data' ], dtype = _np .float32 ). reshape (( - 1 , 3 , 32 , 32 ))
3636 yte = _np .array (batch ['labels' ])
3737 Xte /= 255
3838
0 commit comments