Skip to content

Commit 413be60

Browse files
committed
images to float32
1 parent 112e04f commit 413be60

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

mitdeeplearning/lab2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,14 @@ def get_batch(self, n, only_faces=False, p_pos=None, p_neg=None, return_inds=Fal
7676
selected_inds = np.concatenate((selected_pos_inds, selected_neg_inds))
7777

7878
sorted_inds = np.sort(selected_inds)
79-
train_img = self.images[sorted_inds,:,:,::-1]/255.
79+
train_img = (self.images[sorted_inds,:,:,::-1]/255.).astype(np.float32)
8080
train_label = self.labels[sorted_inds,...]
8181
return (train_img, train_label, sorted_inds) if return_inds else (train_img, train_label)
8282

8383
def get_n_most_prob_faces(self, prob, n):
8484
idx = np.argsort(prob)[::-1]
8585
most_prob_inds = self.pos_train_inds[idx[:10*n:10]]
86-
return self.images[most_prob_inds,...]/255.
86+
return (self.images[most_prob_inds,...]/255.).astype(np.float32)
8787

8888
def get_all_train_faces(self):
8989
return self.images[ self.pos_train_inds ]
@@ -93,7 +93,7 @@ def get_test_faces():
9393
cwd = os.path.dirname(__file__)
9494
f = h5py.File(os.path.join(cwd, "data", "test_faces.h5py"), "r")
9595
def get(key):
96-
return f[key][:][:,:,:,::-1]/255.
96+
return (f[key][:][:,:,:,::-1]/255.).astype(np.float32)
9797
return get("LM"), get("LF"), get("DM"), get("DF")
9898

9999

0 commit comments

Comments
 (0)