@@ -76,14 +76,14 @@ def get_batch(self, n, only_faces=False, p_pos=None, p_neg=None, return_inds=Fal
76
76
selected_inds = np .concatenate ((selected_pos_inds , selected_neg_inds ))
77
77
78
78
sorted_inds = np .sort (selected_inds )
79
- train_img = self .images [sorted_inds ,:,:,::- 1 ]/ 255.
79
+ train_img = ( self .images [sorted_inds ,:,:,::- 1 ]/ 255. ). astype ( np . float32 )
80
80
train_label = self .labels [sorted_inds ,...]
81
81
return (train_img , train_label , sorted_inds ) if return_inds else (train_img , train_label )
82
82
83
83
def get_n_most_prob_faces (self , prob , n ):
84
84
idx = np .argsort (prob )[::- 1 ]
85
85
most_prob_inds = self .pos_train_inds [idx [:10 * n :10 ]]
86
- return self .images [most_prob_inds ,...]/ 255.
86
+ return ( self .images [most_prob_inds ,...]/ 255. ). astype ( np . float32 )
87
87
88
88
def get_all_train_faces (self ):
89
89
return self .images [ self .pos_train_inds ]
@@ -93,7 +93,7 @@ def get_test_faces():
93
93
cwd = os .path .dirname (__file__ )
94
94
f = h5py .File (os .path .join (cwd , "data" , "test_faces.h5py" ), "r" )
95
95
def get (key ):
96
- return f [key ][:][:,:,:,::- 1 ]/ 255.
96
+ return ( f [key ][:][:,:,:,::- 1 ]/ 255. ). astype ( np . float32 )
97
97
return get ("LM" ), get ("LF" ), get ("DM" ), get ("DF" )
98
98
99
99
0 commit comments