2424import json
2525import os
2626
27- # from keras import backend as k
28- # from keras.datasets.cifar import load_batch
29- # from keras.preprocessing import image
30- # from keras.utils import data_utils
31- # from keras.utils.data_utils import get_file
3227import numpy as np
3328
3429
@@ -125,10 +120,11 @@ def preprocess(x, y, nb_classes=10, max_value=255):
125120def load_cifar10 ():
126121 """Loads CIFAR10 dataset from config.CIFAR10_PATH or downloads it if necessary.
127122
128- :return: (x_train, y_train), (x_test, y_test), min, max
129- :rtype: (tuple of numpy .ndarray), (tuple of numpy .ndarray), float, float
123+ :return: ` (x_train, y_train), (x_test, y_test), min, max`
124+ :rtype: `(np.ndarray, np .ndarray), (np.ndarray, np .ndarray), float, float`
130125 """
131126 from config import CIFAR10_PATH
127+ import keras .backend as k
132128 from keras .datasets .cifar import load_batch
133129 from keras .utils .data_utils import get_file
134130
@@ -166,8 +162,8 @@ def load_cifar10():
166162def load_mnist ():
167163 """Loads MNIST dataset from config.MNIST_PATH or downloads it if necessary.
168164
169- :return: (x_train, y_train), (x_test, y_test), min, max
170- :rtype: tuple of numpy .ndarray), (tuple of numpy .ndarray), float, float
165+ :return: ` (x_train, y_train), (x_test, y_test), min, max`
166+ :rtype: `(np.ndarray, np .ndarray), (np.ndarray, np .ndarray), float, float`
171167 """
172168 from config import MNIST_PATH
173169 from keras .utils .data_utils import get_file
@@ -195,8 +191,8 @@ def load_mnist():
195191def load_imagenet ():
196192 """Loads Imagenet dataset from config.IMAGENET_PATH
197193
198- :return: (x_train, y_train), (x_test, y_test), min, max
199- :rtype: tuple of numpy .ndarray), (tuple of numpy .ndarray), float, float
194+ :return: ` (x_train, y_train), (x_test, y_test), min, max`
195+ :rtype: `(np.ndarray, np .ndarray), (np.ndarray, np .ndarray), float, float`
200196 """
201197 from config import IMAGENET_PATH
202198 from keras .preprocessing import image
@@ -238,10 +234,13 @@ def load_imagenet():
238234def load_stl ():
239235 """Loads the STL-10 dataset from config.STL10_PATH or downloads it if necessary.
240236
241- :return: (x_train, y_train), (x_test, y_test), min, max
242- :rtype: tuple of numpy .ndarray), (tuple of numpy .ndarray), float, float
237+ :return: ` (x_train, y_train), (x_test, y_test), min, max`
238+ :rtype: `(np.ndarray, np .ndarray), (np.ndarray, np .ndarray), float, float`
243239 """
240+ from os .path import join
241+
244242 from config import STL10_PATH
243+ import keras .backend as k
245244 from keras .utils .data_utils import get_file
246245
247246 min_ , max_ = 0. , 1.
@@ -250,23 +249,23 @@ def load_stl():
250249 path = get_file ('stl10_binary' , cache_subdir = STL10_PATH , untar = True ,
251250 origin = 'https://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz' )
252251
253- with open (os . path . join (path , 'train_X.bin' ), 'rb' ) as f :
252+ with open (join (path , str ( 'train_X.bin' )), str ( 'rb' ) ) as f :
254253 x_train = np .fromfile (f , dtype = np .uint8 )
255254 x_train = np .reshape (x_train , (- 1 , 3 , 96 , 96 ))
256255
257- with open (os . path . join (path , 'test_X.bin' ), 'rb' ) as f :
256+ with open (join (path , str ( 'test_X.bin' )), str ( 'rb' ) ) as f :
258257 x_test = np .fromfile (f , dtype = np .uint8 )
259258 x_test = np .reshape (x_test , (- 1 , 3 , 96 , 96 ))
260259
261260 if k .image_data_format () == 'channels_last' :
262261 x_train = x_train .transpose (0 , 2 , 3 , 1 )
263262 x_test = x_test .transpose (0 , 2 , 3 , 1 )
264263
265- with open (os . path . join (path , 'train_y.bin' ), 'rb' ) as f :
264+ with open (join (path , str ( 'train_y.bin' )), str ( 'rb' ) ) as f :
266265 y_train = np .fromfile (f , dtype = np .uint8 )
267266 y_train -= 1
268267
269- with open (os . path . join (path , 'test_y.bin' ), 'rb' ) as f :
268+ with open (join (path , str ( 'test_y.bin' )), str ( 'rb' ) ) as f :
270269 y_test = np .fromfile (f , dtype = np .uint8 )
271270 y_test -= 1
272271
@@ -282,8 +281,8 @@ def load_dataset(name):
282281
283282 :param name: Name of the dataset
284283 :type name: `str`
285- :return: The dataset separated in training and test sets as `(x_train, y_train), (x_test, y_test)`
286- :rtype: `tuple `
284+ :return: The dataset separated in training and test sets as `(x_train, y_train), (x_test, y_test), min, max `
285+ :rtype: `(np.ndarray, np.ndarray), (np.ndarray, np.ndarray), float, float `
287286 :raises NotImplementedError: If the dataset is unknown.
288287 """
289288
0 commit comments