@@ -75,6 +75,7 @@ def __init__(self, data_path, batch_size, training=True):
75
75
76
76
self .train_inds = np .concatenate ((self .pos_train_inds , self .neg_train_inds ))
77
77
self .batch_size = batch_size
78
+ self .p_pos = np .ones (self .pos_train_inds .shape )/ len (self .pos_train_inds )
78
79
79
80
def get_train_size (self ):
80
81
return self .pos_train_inds .shape [0 ] + self .neg_train_inds .shape [0 ]
@@ -84,7 +85,7 @@ def __len__(self):
84
85
85
86
def __getitem__ (self , index ):
86
87
selected_pos_inds = np .random .choice (
87
- self .pos_train_inds , size = self .batch_size // 2 , replace = False
88
+ self .pos_train_inds , size = self .batch_size // 2 , replace = False , p = self . p_pos
88
89
)
89
90
selected_neg_inds = np .random .choice (
90
91
self .neg_train_inds , size = self .batch_size // 2 , replace = False
@@ -94,8 +95,7 @@ def __getitem__(self, index):
94
95
sorted_inds = np .sort (selected_inds )
95
96
train_img = (self .images [sorted_inds ] / 255.0 ).astype (np .float32 )
96
97
train_label = self .labels [sorted_inds , ...]
97
- inds = np .random .permutation (np .arange (len (train_img )))
98
- return np .array (train_img [inds ]), np .array (train_label [inds ])
98
+ return np .array (train_img ), np .array (train_label )
99
99
100
100
def get_n_most_prob_faces (self , prob , n ):
101
101
idx = np .argsort (prob )[::- 1 ]
@@ -121,7 +121,7 @@ def get_test_faces():
121
121
return images ["LF" ], images ["LM" ], images ["DF" ], images ["DM" ]
122
122
123
123
124
- def plot_k (imgs ):
124
+ def plot_k (imgs , fname = None ):
125
125
fig = plt .figure ()
126
126
fig .subplots_adjust (hspace = 0.6 )
127
127
num_images = len (imgs )
@@ -133,10 +133,12 @@ def plot_k(imgs):
133
133
ax .imshow (img_to_show , interpolation = "nearest" )
134
134
plt .subplots_adjust (wspace = 0.20 , hspace = 0.20 )
135
135
plt .show ()
136
+ if fname :
137
+ plt .savefig (fname )
136
138
plt .clf ()
137
139
138
140
139
- def plot_percentile (imgs ):
141
+ def plot_percentile (imgs , fname = None ):
140
142
fig = plt .figure ()
141
143
fig , axs = plt .subplots (1 , len (imgs ), figsize = (11 , 8 ))
142
144
for img in range (len (imgs )):
@@ -145,3 +147,26 @@ def plot_percentile(imgs):
145
147
ax .yaxis .set_visible (False )
146
148
img_to_show = imgs [img ]
147
149
ax .imshow (img_to_show , interpolation = "nearest" )
150
+ if fname :
151
+ plt .savefig (fname )
152
+
153
+ def plot_accuracy_vs_risk (sorted_images , sorted_uncertainty , sorted_preds , plot_title ):
154
+ num_percentile_intervals = 10
155
+ num_samples = len (sorted_images ) // num_percentile_intervals
156
+ all_imgs = []
157
+ all_unc = []
158
+ all_acc = []
159
+ for percentile in range (num_percentile_intervals ):
160
+ cur_imgs = sorted_images [percentile * num_samples : (percentile + 1 ) * num_samples ]
161
+ cur_unc = sorted_uncertainty [percentile * num_samples : (percentile + 1 ) * num_samples ]
162
+ cur_predictions = tf .nn .sigmoid (sorted_preds [percentile * num_samples : (percentile + 1 ) * num_samples ])
163
+ avged_imgs = tf .reduce_mean (cur_imgs , axis = 0 )
164
+ all_imgs .append (avged_imgs )
165
+ all_unc .append (tf .reduce_mean (cur_unc ))
166
+ all_acc .append ((np .ones ((num_samples )) == np .rint (cur_predictions )).mean ())
167
+
168
+ plt .plot (np .arange (num_percentile_intervals ) * 10 , all_acc )
169
+ plt .title (plot_title )
170
+ plt .show ()
171
+ plt .clf ()
172
+ return all_imgs
0 commit comments