Skip to content

Commit dbf8b30

Browse files
Irina NicolaeIrina Nicolae
authored andcommitted
Correct imports in utils
(cherry picked from commit 3e3def6) (cherry picked from commit 4e8658b)
1 parent c041874 commit dbf8b30

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

art/utils.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,6 @@
2424
import json
2525
import os
2626

27-
# from keras import backend as k
28-
# from keras.datasets.cifar import load_batch
29-
# from keras.preprocessing import image
30-
# from keras.utils import data_utils
31-
# from keras.utils.data_utils import get_file
3227
import numpy as np
3328

3429

@@ -125,10 +120,11 @@ def preprocess(x, y, nb_classes=10, max_value=255):
125120
def load_cifar10():
126121
"""Loads CIFAR10 dataset from config.CIFAR10_PATH or downloads it if necessary.
127122
128-
:return: (x_train, y_train), (x_test, y_test), min, max
129-
:rtype: (tuple of numpy.ndarray), (tuple of numpy.ndarray), float, float
123+
:return: `(x_train, y_train), (x_test, y_test), min, max`
124+
:rtype: `(np.ndarray, np.ndarray), (np.ndarray, np.ndarray), float, float`
130125
"""
131126
from config import CIFAR10_PATH
127+
import keras.backend as k
132128
from keras.datasets.cifar import load_batch
133129
from keras.utils.data_utils import get_file
134130

@@ -166,8 +162,8 @@ def load_cifar10():
166162
def load_mnist():
167163
"""Loads MNIST dataset from config.MNIST_PATH or downloads it if necessary.
168164
169-
:return: (x_train, y_train), (x_test, y_test), min, max
170-
:rtype: tuple of numpy.ndarray), (tuple of numpy.ndarray), float, float
165+
:return: `(x_train, y_train), (x_test, y_test), min, max`
166+
:rtype: `(np.ndarray, np.ndarray), (np.ndarray, np.ndarray), float, float`
171167
"""
172168
from config import MNIST_PATH
173169
from keras.utils.data_utils import get_file
@@ -195,8 +191,8 @@ def load_mnist():
195191
def load_imagenet():
196192
"""Loads Imagenet dataset from config.IMAGENET_PATH
197193
198-
:return: (x_train, y_train), (x_test, y_test), min, max
199-
:rtype: tuple of numpy.ndarray), (tuple of numpy.ndarray), float, float
194+
:return: `(x_train, y_train), (x_test, y_test), min, max`
195+
:rtype: `(np.ndarray, np.ndarray), (np.ndarray, np.ndarray), float, float`
200196
"""
201197
from config import IMAGENET_PATH
202198
from keras.preprocessing import image
@@ -238,10 +234,13 @@ def load_imagenet():
238234
def load_stl():
239235
"""Loads the STL-10 dataset from config.STL10_PATH or downloads it if necessary.
240236
241-
:return: (x_train, y_train), (x_test, y_test), min, max
242-
:rtype: tuple of numpy.ndarray), (tuple of numpy.ndarray), float, float
237+
:return: `(x_train, y_train), (x_test, y_test), min, max`
238+
:rtype: `(np.ndarray, np.ndarray), (np.ndarray, np.ndarray), float, float`
243239
"""
240+
from os.path import join
241+
244242
from config import STL10_PATH
243+
import keras.backend as k
245244
from keras.utils.data_utils import get_file
246245

247246
min_, max_ = 0., 1.
@@ -250,23 +249,23 @@ def load_stl():
250249
path = get_file('stl10_binary', cache_subdir=STL10_PATH, untar=True,
251250
origin='https://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz')
252251

253-
with open(os.path.join(path, 'train_X.bin'), 'rb') as f:
252+
with open(join(path, str('train_X.bin')), str('rb')) as f:
254253
x_train = np.fromfile(f, dtype=np.uint8)
255254
x_train = np.reshape(x_train, (-1, 3, 96, 96))
256255

257-
with open(os.path.join(path, 'test_X.bin'), 'rb') as f:
256+
with open(join(path, str('test_X.bin')), str('rb')) as f:
258257
x_test = np.fromfile(f, dtype=np.uint8)
259258
x_test = np.reshape(x_test, (-1, 3, 96, 96))
260259

261260
if k.image_data_format() == 'channels_last':
262261
x_train = x_train.transpose(0, 2, 3, 1)
263262
x_test = x_test.transpose(0, 2, 3, 1)
264263

265-
with open(os.path.join(path, 'train_y.bin'), 'rb') as f:
264+
with open(join(path, str('train_y.bin')), str('rb')) as f:
266265
y_train = np.fromfile(f, dtype=np.uint8)
267266
y_train -= 1
268267

269-
with open(os.path.join(path, 'test_y.bin'), 'rb') as f:
268+
with open(join(path, str('test_y.bin')), str('rb')) as f:
270269
y_test = np.fromfile(f, dtype=np.uint8)
271270
y_test -= 1
272271

@@ -282,8 +281,8 @@ def load_dataset(name):
282281
283282
:param name: Name of the dataset
284283
:type name: `str`
285-
:return: The dataset separated in training and test sets as `(x_train, y_train), (x_test, y_test)`
286-
:rtype: `tuple`
284+
:return: The dataset separated in training and test sets as `(x_train, y_train), (x_test, y_test), min, max`
285+
:rtype: `(np.ndarray, np.ndarray), (np.ndarray, np.ndarray), float, float`
287286
:raises NotImplementedError: If the dataset is unknown.
288287
"""
289288

0 commit comments

Comments
 (0)